Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
49
benchmark/hf3fs/bench.sh
Normal file
49
benchmark/hf3fs/bench.sh
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Collection of commands for benchmarking the HF3FS HiCache storage backend.
# The sections are separated by the '#' rulers and start background jobs, so
# they are meant to be run one section at a time rather than top to bottom.

# Storage-layer micro benchmark. bench_storage.py reads the config path from
# SGLANG_HICACHE_HF3FS_CONFIG_PATH and writes its own test config to it.
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
python3 benchmark/hf3fs/bench_storage.py

####################################################################################################

# Qwen3-32B: start a server with the hierarchical cache backed by hf3fs.
# Server output goes to nohup.out (removed first so the log starts clean).
rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 33301 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs &

# Multi-turn client against the server above; results go to bench_multiturn.out.
# NOTE(review): the server is started in the background just above, so
# presumably it must finish loading before this client is launched.
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
    --model-path /code/models/Qwen3-32B \
    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 33301 \
    --request-length 2048 --num-clients 512 --num-rounds 3 --max-parallel 8 \
    > bench_multiturn.out &

####################################################################################################

# DeepSeek-R1: two-node run with tensor parallelism 16. This is the command
# for node rank 0; the second node presumably needs its own --node-rank.
rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/DeepSeek-R1/ \
    --tp 16 --nnodes 2 --node-rank 0 \
    --dist-init-addr 10.74.249.153:5000 \
    --host 0.0.0.0 --port 33301 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 60 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs &

# NOTE(review): this client passes the Qwen3-32B model path while the server
# above serves DeepSeek-R1 -- confirm that this is intended.
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
    --model-path /code/models/Qwen3-32B \
    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 33301 \
    --request-length 2048 --num-clients 1024 --num-rounds 3 --max-parallel 8 \
    > bench_multiturn.out &

####################################################################################################

# Stop the server and the benchmark client. pkill matches against the full
# command line (-f) and, unlike `ps | grep | xargs kill`, does not run `kill`
# with an empty argument list when no matching process exists.
pkill -9 -f "sglang.launch_server"
pkill -9 -f "bench_multiturn.py"
|
||||||
162
benchmark/hf3fs/bench_client.py
Normal file
162
benchmark/hf3fs/bench_client.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
import concurrent.futures
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
||||||
|
|
||||||
|
|
||||||
|
def print_stats(x: List[float]):
    """Print mean, min, p25, p50, p75 and max of ``x`` on a single line.

    Every value is formatted with two decimals. The percentiles are taken by
    indexing the sorted samples at ``int(len(x) * q)``; no interpolation is
    performed.

    Args:
        x: Samples to summarize (the benchmarks pass bandwidth values).
            An empty list prints ``no samples`` instead of raising
            ``ZeroDivisionError``.
    """
    if not x:
        print("no samples")
        return

    ordered = sorted(x)
    count = len(ordered)
    print(
        f"mean = {sum(ordered)/count:.2f}, "
        f"min = {ordered[0]:.2f}, "
        f"p25 = {ordered[int(count*0.25)]:.2f}, "
        f"p50 = {ordered[int(count*0.5)]:.2f}, "
        f"p75 = {ordered[int(count*0.75)]:.2f}, "
        f"max = {ordered[-1]:.2f}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test():
    """Round-trip check of ``Hf3fsClient.batch_write`` / ``batch_read``.

    Writes 128 random bfloat16 pages to randomly chosen, distinct page-aligned
    offsets of the benchmark file, reads them back from the same offsets and
    asserts that every I/O reported a full page and that the data matches.
    The client is closed even when one of the assertions fails.
    """
    # /path/to/hf3fs
    file_path = "/data/bench.bin"
    file_size = 1 << 40
    bytes_per_page = 16 << 20
    entries = 32
    file_ops = Hf3fsClient(file_path, file_size, bytes_per_page, entries)

    try:
        print("test batch_read / batch_write")
        num_pages = 128
        dtype = torch.bfloat16
        numel = bytes_per_page // dtype.itemsize
        expected_bytes = numel * dtype.itemsize

        # Pick distinct random pages directly instead of materializing and
        # shuffling the full list of page indices.
        page_ids = random.sample(range(file_size // bytes_per_page), num_pages)
        offsets = [page_id * bytes_per_page for page_id in page_ids]

        tensor_writes = [
            torch.randn(numel, dtype=dtype)
            for _ in tqdm(range(num_pages), desc="prepare tensor")
        ]
        # Each call handles at most `file_ops.entries` pages.
        for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"):
            results = file_ops.batch_write(
                offsets[i : i + file_ops.entries],
                tensor_writes[i : i + file_ops.entries],
            )
            assert all(result == expected_bytes for result in results)

        tensor_reads = [
            torch.empty(numel, dtype=dtype)
            for _ in tqdm(range(num_pages), desc="prepare tensor")
        ]
        for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"):
            results = file_ops.batch_read(
                offsets[i : i + file_ops.entries],
                tensor_reads[i : i + file_ops.entries],
            )
            assert all(result == expected_bytes for result in results)

        assert all(torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes))
    finally:
        file_ops.close()

    print("test done")
|
||||||
|
|
||||||
|
|
||||||
|
def bench():
    """Benchmark concurrent ``Hf3fsClient`` write and read bandwidth.

    ``numjobs`` clients are opened on the same file and driven from a thread
    pool. Each iteration submits one batch of ``entries`` pages per client to
    page offsets that are unique across clients and iterations. The first
    ``warmup`` iterations are excluded from the statistics. Bandwidth is
    computed from sizes in GiB (``1 << 30`` bytes) and printed via
    ``print_stats``; the read phase reuses the offsets of the write phase.

    The thread pool is shut down and every client is closed even when a phase
    fails.
    """
    file_path = "/data/bench.bin"
    file_size = 1 << 40
    bytes_per_page = 16 << 20
    entries = 8
    numjobs = 16

    dtype = torch.bfloat16
    numel = bytes_per_page // dtype.itemsize

    file_ops = [
        Hf3fsClient(file_path, file_size, bytes_per_page, entries)
        for _ in range(numjobs)
    ]

    num_page = entries

    offsets = list(range(file_size // bytes_per_page))
    # Every list entry aliases one tensor: only throughput is measured here,
    # so all jobs share a single source page and a single destination page.
    tensors_write = [torch.randn(numel, dtype=dtype)] * num_page
    tensors_read = [torch.empty(numel, dtype=dtype)] * num_page
    random.shuffle(offsets)

    warmup = 50
    iteration = 100

    # Data moved per iteration across all jobs, in GiB.
    batch_gib = num_page * numjobs * bytes_per_page / (1 << 30)

    def run_phase(executor, op_name, tensors, desc):
        """Run one timed phase calling ``op_name`` on every client; return bandwidths."""
        bandwidths = []
        for i in tqdm(range(warmup + iteration), desc=desc):
            # Job j of iteration i owns a private slice of the shuffled pages.
            job_offsets = [
                [
                    page * bytes_per_page
                    for page in offsets[
                        (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
                    ]
                ]
                for j in range(numjobs)
            ]
            tik = time.perf_counter()
            futures = [
                executor.submit(getattr(file_ops[j], op_name), job_offset, tensors)
                for j, job_offset in enumerate(job_offsets)
            ]
            results = [future.result() for future in futures]
            tok = time.perf_counter()
            if i < warmup:
                continue
            bandwidths.append(batch_gib / (tok - tik))
            assert all(
                nbytes == bytes_per_page for result in results for nbytes in result
            )
        return bandwidths

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=numjobs) as executor:
            w_bw = run_phase(
                executor, "batch_write", tensors_write, "Benchmarking write (GB/s)"
            )
            print_stats(w_bw)

            r_bw = run_phase(
                executor, "batch_read", tensors_read, "Benchmarking read (GB/s)"
            )
            print_stats(r_bw)
    finally:
        for _file_ops in file_ops:
            _file_ops.close()

    print("bench done")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Enable INFO-level logging, then run the round-trip test and the benchmark."""
    logging.basicConfig(level=logging.INFO)
    # Correctness first, so the benchmark only runs against a working client.
    for stage in (test, bench):
        stage()


if __name__ == "__main__":
    main()
|
||||||
241
benchmark/hf3fs/bench_storage.py
Normal file
241
benchmark/hf3fs/bench_storage.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||||
|
|
||||||
|
|
||||||
|
def print_stats(x: List[float]):
    """Print mean, min, p25, p50, p75 and max of ``x`` on a single line.

    Every value is formatted with two decimals. The percentiles are taken by
    indexing the sorted samples at ``int(len(x) * q)``; no interpolation is
    performed.

    Args:
        x: Samples to summarize (the benchmarks pass bandwidth values).
            An empty list prints ``no samples`` instead of raising
            ``ZeroDivisionError``.
    """
    if not x:
        print("no samples")
        return

    ordered = sorted(x)
    count = len(ordered)
    print(
        f"mean = {sum(ordered)/count:.2f}, "
        f"min = {ordered[0]:.2f}, "
        f"p25 = {ordered[int(count*0.25)]:.2f}, "
        f"p50 = {ordered[int(count*0.5)]:.2f}, "
        f"p75 = {ordered[int(count*0.75)]:.2f}, "
        f"max = {ordered[-1]:.2f}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test():
    """Functional test of ``HiCacheHF3FS`` single-key and batch operations.

    Writes a small config to the path named by the environment variable
    ``HiCacheHF3FS.default_env_var``, builds the backend from it and checks
    ``set`` / ``get`` / ``exists`` / ``delete`` / ``clear`` as well as
    ``batch_set`` / ``batch_get`` when more pages are inserted than the
    backend reports in ``num_pages``. The backend is closed and its file is
    removed even when an assertion fails.

    Raises:
        RuntimeError: If the config file cannot be written.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path_prefix = "/data/test"
    file_size = 128 << 20
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 2
    dtype = store_dtype

    config_path = os.getenv(HiCacheHF3FS.default_env_var)
    assert config_path, f"{HiCacheHF3FS.default_env_var} is not set"
    config = {
        "file_path_prefix": file_path_prefix,
        "file_size": file_size,
        "numjobs": numjobs,
        "entries": entries,
    }
    try:
        with open(config_path, "w") as f:
            json.dump(config, f)
    except OSError as e:
        raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}") from e

    rank = 0
    hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype)

    try:
        # K and V (factor 2) for one page of tokens across all layers and heads
        # must fill exactly one storage page.
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        # Insert more pages than the backend holds; the first keys are
        # expected to be gone afterwards and the most recent ones to remain.
        num_pages = 10
        tensors = {}
        for i in range(num_pages):
            k = f"key_{i}"
            v = torch.randn((numel,)).to(dtype=dtype)
            ok = hicache_hf3fs.set(k, v)
            assert ok, f"Failed to insert {k}"
            tensors[k] = v
        assert hicache_hf3fs.get("key_0") is None
        assert hicache_hf3fs.get("key_1") is None

        start = num_pages - hicache_hf3fs.num_pages
        for i in range(start, start + hicache_hf3fs.num_pages):
            k = f"key_{i}"
            assert hicache_hf3fs.exists(k)
            out = hicache_hf3fs.get(k)
            assert out is not None
            v = tensors[k]
            assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}"

        assert not hicache_hf3fs.exists("not_exists")

        # Deleting a key must make room for a new one.
        hicache_hf3fs.delete("key_9")
        v2 = torch.randn((numel,)).to(dtype=dtype)
        assert hicache_hf3fs.set("key_new", v2)
        assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)

        hicache_hf3fs.clear()
        assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages

        # batch
        num_pages = 10
        tensors = {}
        keys = []
        values = []
        for i in range(num_pages):
            k = f"key_{i}"
            keys.append(k)
            v = torch.randn((numel,)).to(dtype=dtype)
            values.append(v)

        # The batch is larger than the backend: it must report failure, the
        # trailing keys must be absent and the leading ones readable.
        ok = hicache_hf3fs.batch_set(keys, values)
        assert not ok
        assert hicache_hf3fs.get("key_8") is None
        assert hicache_hf3fs.get("key_9") is None

        results = hicache_hf3fs.batch_get(keys[: hicache_hf3fs.num_pages])
        for result, key, value in zip(
            results, keys[: hicache_hf3fs.num_pages], values[: hicache_hf3fs.num_pages]
        ):
            assert torch.allclose(value, result, atol=1e-3), f"Tensor mismatch for {key}"
    finally:
        hicache_hf3fs.close()
        os.remove(hicache_hf3fs.file_path)

    print("All test cases passed.")
|
||||||
|
|
||||||
|
|
||||||
|
def bench():
    """Benchmark ``HiCacheHF3FS.batch_set`` / ``batch_get`` bandwidth.

    Each iteration writes (then reads) ``num_page`` pages in one batch call.
    The first ``warmup`` iterations are excluded from the statistics.
    Bandwidth is computed from sizes in GiB (``1 << 30`` bytes) and printed
    via ``print_stats``. Read batches are sampled at random from the keys the
    backend currently tracks in ``key_to_index``. The backend is closed even
    when an assertion fails.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    dtype = store_dtype
    hicache_hf3fs = HiCacheHF3FS(
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
    )

    try:
        # K and V (factor 2) for one page of tokens across all layers and heads
        # must fill exactly one storage page.
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        num_page = 128
        values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]

        warmup = 50
        iteration = 100

        w_bw = []
        w_size = num_page * bytes_per_page / (1 << 30)
        for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
            # Fresh keys every iteration so that each batch stores new entries.
            keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
            tik = time.perf_counter()
            ok = hicache_hf3fs.batch_set(keys, values)
            tok = time.perf_counter()
            if i < warmup:
                continue
            w_bw.append(w_size / (tok - tik))
            assert ok
        print_stats(w_bw)

        r_bw = []
        r_size = num_page * bytes_per_page / (1 << 30)
        for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
            keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
            tik = time.perf_counter()
            results = hicache_hf3fs.batch_get(keys)
            tok = time.perf_counter()
            if i < warmup:
                continue
            r_bw.append(r_size / (tok - tik))
            assert all(r is not None for r in results)
        print_stats(r_bw)
    finally:
        hicache_hf3fs.close()
|
||||||
|
|
||||||
|
|
||||||
|
def allclose():
    """Data-integrity check of ``HiCacheHF3FS`` batch writes and reads.

    Writes ``iteration`` batches of ``num_page`` pages under increasing integer
    keys (key ``j`` stores ``values[j % num_page]``), reads random batches back
    and, once all reads have finished, compares every result with the tensor
    that was written under its key. The backend is closed even when an
    assertion fails.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    dtype = store_dtype
    hicache_hf3fs = HiCacheHF3FS(
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
    )

    try:
        # K and V (factor 2) for one page of tokens across all layers and heads
        # must fill exactly one storage page.
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        num_page = 128
        values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]

        iteration = 100

        for i in tqdm(range(iteration), desc="Writing pages"):
            keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
            ok = hicache_hf3fs.batch_set(keys, values)
            assert ok

        # Results are deliberately kept until all reads are done and only then
        # verified. NOTE(review): if every result is a full 16 MiB page this
        # retains iteration * num_page pages in memory -- confirm the host has
        # enough RAM.
        read_keys, read_results = [], []
        for i in tqdm(range(iteration), desc="Reading pages"):
            keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
            results = hicache_hf3fs.batch_get(keys)
            read_keys.extend(keys)
            read_results.extend(results)
            assert all(r is not None for r in results)

        for key, result in tqdm(
            zip(read_keys, read_results), total=len(read_keys), desc="Verifying pages"
        ):
            assert torch.allclose(values[int(key) % num_page], result, atol=1e-3)
    finally:
        hicache_hf3fs.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Enable INFO-level logging, then run the test, the benchmark and the integrity check."""
    logging.basicConfig(level=logging.INFO)
    # Same order as before: functional test, bandwidth benchmark, data check.
    for stage in (test, bench, allclose):
        stage()


if __name__ == "__main__":
    main()
|
||||||
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||||
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -250,17 +251,33 @@ class HiCacheController:
|
|||||||
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
|
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
|
||||||
if self.tp_world_size > 1:
|
if self.tp_world_size > 1:
|
||||||
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
|
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
|
||||||
self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
self.prefetch_tp_group = torch.distributed.new_group(
|
||||||
|
group_ranks, backend="gloo"
|
||||||
|
)
|
||||||
|
self.backup_tp_group = torch.distributed.new_group(
|
||||||
|
group_ranks, backend="gloo"
|
||||||
|
)
|
||||||
|
|
||||||
if storage_backend == "file":
|
if storage_backend == "file":
|
||||||
self.storage_backend = HiCacheFile()
|
self.storage_backend = HiCacheFile()
|
||||||
self.enable_storage = True
|
elif storage_backend == "hf3fs":
|
||||||
# todo: threshold policy for prefetching
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
|
||||||
|
rank = get_tensor_model_parallel_rank()
|
||||||
|
bytes_per_page = (
|
||||||
|
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||||
|
)
|
||||||
|
dtype = mem_pool_host.dtype
|
||||||
|
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||||
|
rank, bytes_per_page, dtype
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unsupported storage backend: {storage_backend}"
|
f"Unsupported storage backend: {storage_backend}"
|
||||||
)
|
)
|
||||||
|
self.enable_storage = True
|
||||||
|
# todo: threshold policy for prefetching
|
||||||
|
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
||||||
|
|
||||||
self.load_cache_event = load_cache_event
|
self.load_cache_event = load_cache_event
|
||||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||||
@@ -522,8 +539,8 @@ class HiCacheController:
|
|||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||||
for h in operation.hash_value:
|
page_datas = self.storage_backend.batch_get(operation.hash_value)
|
||||||
page_data = self.storage_backend.get(h)
|
for h, page_data in zip(operation.hash_value, page_datas):
|
||||||
if page_data is None:
|
if page_data is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
||||||
@@ -531,7 +548,9 @@ class HiCacheController:
|
|||||||
break
|
break
|
||||||
if operation.increment(self.page_size):
|
if operation.increment(self.page_size):
|
||||||
self.mem_pool_host.set_from_flat_data_page(
|
self.mem_pool_host.set_from_flat_data_page(
|
||||||
operation.host_indices[operation.completed_tokens],
|
operation.host_indices[
|
||||||
|
operation.completed_tokens - self.page_size
|
||||||
|
],
|
||||||
page_data,
|
page_data,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -583,7 +602,7 @@ class HiCacheController:
|
|||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
storage_hit_count_tensor,
|
storage_hit_count_tensor,
|
||||||
op=torch.distributed.ReduceOp.MIN,
|
op=torch.distributed.ReduceOp.MIN,
|
||||||
group=self.tp_group,
|
group=self.prefetch_tp_group,
|
||||||
)
|
)
|
||||||
storage_hit_count = storage_hit_count_tensor.item()
|
storage_hit_count = storage_hit_count_tensor.item()
|
||||||
|
|
||||||
@@ -635,21 +654,23 @@ class HiCacheController:
|
|||||||
last_hash = operation.last_hash
|
last_hash = operation.last_hash
|
||||||
tokens_to_backup = operation.token_ids
|
tokens_to_backup = operation.token_ids
|
||||||
|
|
||||||
|
last_hashes, data_pages = [], []
|
||||||
for i in range(0, len(tokens_to_backup), self.page_size):
|
for i in range(0, len(tokens_to_backup), self.page_size):
|
||||||
last_hash = get_hash_str(
|
last_hash = get_hash_str(
|
||||||
tokens_to_backup[i : i + self.page_size], last_hash
|
tokens_to_backup[i : i + self.page_size], last_hash
|
||||||
)
|
)
|
||||||
success = self.storage_backend.set(
|
data_page = self.mem_pool_host.get_flat_data_page(
|
||||||
last_hash,
|
operation.host_indices[i]
|
||||||
self.mem_pool_host.get_flat_data_page(
|
|
||||||
operation.host_indices[i]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if not success:
|
last_hashes.append(last_hash)
|
||||||
logger.warning(f"Failed to write page {last_hash} to storage.")
|
data_pages.append(data_page)
|
||||||
break
|
|
||||||
operation.completed_tokens += self.page_size
|
success = self.storage_backend.batch_set(last_hashes, data_pages)
|
||||||
operation.hash_value.append(last_hash)
|
if not success:
|
||||||
|
logger.warning(f"Failed to write page {last_hashes} to storage.")
|
||||||
|
else:
|
||||||
|
operation.completed_tokens += len(tokens_to_backup)
|
||||||
|
operation.hash_value.extend(last_hashes)
|
||||||
|
|
||||||
min_completed_tokens = operation.completed_tokens
|
min_completed_tokens = operation.completed_tokens
|
||||||
if self.tp_world_size > 1:
|
if self.tp_world_size > 1:
|
||||||
@@ -659,7 +680,7 @@ class HiCacheController:
|
|||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
completed_tokens_tensor,
|
completed_tokens_tensor,
|
||||||
op=torch.distributed.ReduceOp.MIN,
|
op=torch.distributed.ReduceOp.MIN,
|
||||||
group=self.tp_group,
|
group=self.backup_tp_group,
|
||||||
)
|
)
|
||||||
min_completed_tokens = completed_tokens_tensor.item()
|
min_completed_tokens = completed_tokens_tensor.item()
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,9 @@ class HiRadixCache(RadixCache):
|
|||||||
self.write_through_threshold = (
|
self.write_through_threshold = (
|
||||||
1 if hicache_write_policy == "write_through" else 3
|
1 if hicache_write_policy == "write_through" else 3
|
||||||
)
|
)
|
||||||
self.write_through_threshold_storage = 3
|
self.write_through_threshold_storage = (
|
||||||
|
1 if hicache_write_policy == "write_through" else 3
|
||||||
|
)
|
||||||
self.load_back_threshold = 10
|
self.load_back_threshold = 10
|
||||||
super().__init__(
|
super().__init__(
|
||||||
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||||
@@ -388,10 +390,14 @@ class HiRadixCache(RadixCache):
|
|||||||
self.cache_controller.ack_backup_queue.get()
|
self.cache_controller.ack_backup_queue.get()
|
||||||
)
|
)
|
||||||
host_node = self.ongoing_backup[ack_id]
|
host_node = self.ongoing_backup[ack_id]
|
||||||
if completed_tokens < len(host_node.key):
|
if completed_tokens == 0:
|
||||||
|
host_node.hash_value = None
|
||||||
|
elif completed_tokens < len(host_node.key):
|
||||||
# backup is only partially successful, split the node
|
# backup is only partially successful, split the node
|
||||||
new_node = self._split_node(host_node.key, host_node, completed_tokens)
|
new_node = self._split_node(host_node.key, host_node, completed_tokens)
|
||||||
new_node.hash_value = hash_value
|
new_node.hash_value = hash_value
|
||||||
|
else:
|
||||||
|
host_node.hash_value = hash_value
|
||||||
host_node.release_host()
|
host_node.release_host()
|
||||||
del self.ongoing_backup[ack_id]
|
del self.ongoing_backup[ack_id]
|
||||||
|
|
||||||
@@ -431,6 +437,8 @@ class HiRadixCache(RadixCache):
|
|||||||
written_indices,
|
written_indices,
|
||||||
hash_value[:min_completed_tokens],
|
hash_value[:min_completed_tokens],
|
||||||
)
|
)
|
||||||
|
if len(written_indices):
|
||||||
|
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
||||||
|
|
||||||
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
||||||
self.cache_controller.mem_pool_host.free(
|
self.cache_controller.mem_pool_host.free(
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ def synchronized(debug_only=False):
|
|||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
if (not debug_only) or self.debug:
|
if (not debug_only) or self.debug:
|
||||||
return func(self, *args, **kwargs)
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -181,6 +180,15 @@ class HostKVCache(abc.ABC):
|
|||||||
)
|
)
|
||||||
self.mem_state[indices] = MemoryStateInt.BACKUP
|
self.mem_state[indices] = MemoryStateInt.BACKUP
|
||||||
|
|
||||||
|
@synchronized(debug_only=True)
|
||||||
|
def update_prefetch(self, indices: torch.Tensor):
|
||||||
|
if not self.is_reserved(indices):
|
||||||
|
raise ValueError(
|
||||||
|
f"The host memory slots should be in RESERVED state before turning into BACKUP. "
|
||||||
|
f"Current state: {self.get_state(indices)}"
|
||||||
|
)
|
||||||
|
self.mem_state[indices] = MemoryStateInt.BACKUP
|
||||||
|
|
||||||
@synchronized(debug_only=True)
|
@synchronized(debug_only=True)
|
||||||
def update_synced(self, indices: torch.Tensor):
|
def update_synced(self, indices: torch.Tensor):
|
||||||
self.mem_state[indices] = MemoryStateInt.SYNCED
|
self.mem_state[indices] = MemoryStateInt.SYNCED
|
||||||
|
|||||||
65
python/sglang/srt/mem_cache/storage/hf3fs/README.md
Normal file
65
python/sglang/srt/mem_cache/storage/hf3fs/README.md
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# HiCacheHF3FS Setup
|
||||||
|
|
||||||
|
## Build & Package
|
||||||
|
### Source Code
|
||||||
|
https://github.com/deepseek-ai/3FS/blob/main/README.md#check-out-source-code
|
||||||
|
```sh
|
||||||
|
git clone https://github.com/deepseek-ai/3fs
|
||||||
|
|
||||||
|
cd 3fs
|
||||||
|
git submodule update --init --recursive
|
||||||
|
./patches/apply.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build Dev Container
|
||||||
|
https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile
|
||||||
|
```sh
|
||||||
|
cd 3fs/dockerfile
|
||||||
|
docker build -t hf3fs:dev -f dev.dockerfile .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generate Python Wheel
|
||||||
|
```sh
|
||||||
|
docker run -it hf3fs:dev bash
|
||||||
|
|
||||||
|
# Inside the development container
|
||||||
|
git clone https://github.com/deepseek-ai/3fs
|
||||||
|
|
||||||
|
cd 3fs
|
||||||
|
git submodule update --init --recursive
|
||||||
|
./patches/apply.sh
|
||||||
|
|
||||||
|
apt-get update \
|
||||||
|
&& apt-get install -y --no-install-recommends \
|
||||||
|
python3 python3-pip \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
|
||||||
|
python3 setup.py bdist_wheel
|
||||||
|
```
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
```sh
|
||||||
|
# Install Dependencies
|
||||||
|
# https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile
|
||||||
|
apt update && apt install -y \
|
||||||
|
libaio-dev \
|
||||||
|
libboost-all-dev \
|
||||||
|
libdouble-conversion-dev \
|
||||||
|
libdwarf-dev \
|
||||||
|
libgflags-dev \
|
||||||
|
libgmock-dev \
|
||||||
|
libgoogle-glog-dev \
|
||||||
|
libgoogle-perftools-dev \
|
||||||
|
libgtest-dev \
|
||||||
|
liblz4-dev \
|
||||||
|
liblzma-dev \
|
||||||
|
libssl-dev \
|
||||||
|
libunwind-dev \
|
||||||
|
libuv1-dev
|
||||||
|
|
||||||
|
# Install Python Package
|
||||||
|
pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
|
||||||
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
|
||||||
|
```
|
||||||
177
python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
Normal file
177
python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
|
root = Path(__file__).parent.resolve()
|
||||||
|
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hf3fs_fuse.io import (
|
||||||
|
deregister_fd,
|
||||||
|
extract_mount_point,
|
||||||
|
make_ioring,
|
||||||
|
make_iovec,
|
||||||
|
register_fd,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"hf3fs_fuse.io is not available: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def rsynchronized():
    """Build a method decorator that runs the method while holding ``self.rlock``.

    The decorated method keeps its name and docstring (``functools.wraps``) and
    its return value is passed through unchanged. The lock is released when the
    method returns or raises.
    """

    def decorate(method):
        @wraps(method)
        def locked_call(self, *args, **kwargs):
            with self.rlock:
                outcome = method(self, *args, **kwargs)
            return outcome

        return locked_call

    return decorate
|
||||||
|
|
||||||
|
|
||||||
|
def wsynchronized():
    """Build a method decorator that runs the method while holding ``self.wlock``.

    The decorated method keeps its name and docstring (``functools.wraps``) and
    its return value is passed through unchanged. The lock is released when the
    method returns or raises.
    """

    def decorate(method):
        @wraps(method)
        def locked_call(self, *args, **kwargs):
            with self.wlock:
                outcome = method(self, *args, **kwargs)
            return outcome

        return locked_call

    return decorate
|
||||||
|
|
||||||
|
|
||||||
|
class Hf3fsClient:
|
||||||
|
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
    """Open the backing file and set up the read and write I/O resources.

    The file at ``path`` is created if missing, resized to ``size`` bytes and
    registered via ``register_fd``. For each direction a shared-memory
    staging buffer of ``bytes_per_page * entries`` bytes, an I/O ring with
    ``entries`` slots and an iovec over the staging buffer are created, plus
    one lock per direction.

    Args:
        path: Path of the backing file; its mount point is derived with
            ``extract_mount_point``.
        size: Size in bytes the backing file is truncated/extended to.
        bytes_per_page: Size in bytes of one staging slot.
        entries: Number of slots per ring and per staging buffer.
    """
    # A bare `import multiprocessing` does not load the `shared_memory`
    # submodule, so import it explicitly instead of relying on another
    # module having imported it first.
    from multiprocessing import shared_memory

    self.path = path
    self.size = size
    self.bytes_per_page = bytes_per_page
    self.entries = entries

    self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
    os.ftruncate(self.file, size)
    register_fd(self.file)

    self.hf3fs_mount_point = extract_mount_point(path)
    self.bs = self.bytes_per_page
    # Separate staging buffers so that reads and writes never share memory.
    self.shm_r = shared_memory.SharedMemory(
        size=self.bs * self.entries, create=True
    )
    self.shm_w = shared_memory.SharedMemory(
        size=self.bs * self.entries, create=True
    )

    # Byte views over the staging buffers (no copy is made by frombuffer).
    self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
    self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)

    # NOTE(review): -1 presumably means "no NUMA binding" -- confirm against
    # the hf3fs_fuse.io documentation.
    self.numa = -1
    self.ior_r = make_ioring(
        self.hf3fs_mount_point,
        self.entries,
        for_read=True,
        timeout=1,
        numa=self.numa,
    )
    self.ior_w = make_ioring(
        self.hf3fs_mount_point,
        self.entries,
        for_read=False,
        timeout=1,
        numa=self.numa,
    )
    self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
    self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)

    # One lock per direction: used by the rsynchronized / wsynchronized
    # decorators to serialize batch_read and batch_write independently.
    self.rlock = threading.RLock()
    self.wlock = threading.RLock()
|
||||||
|
|
||||||
|
@rsynchronized()
|
||||||
|
def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
|
||||||
|
self.check(offsets, tensors)
|
||||||
|
|
||||||
|
# prepare
|
||||||
|
current = 0
|
||||||
|
for offset, tensor in zip(offsets, tensors):
|
||||||
|
size = tensor.numel() * tensor.itemsize
|
||||||
|
self.ior_r.prepare(
|
||||||
|
self.iov_r[current : current + size], True, self.file, offset
|
||||||
|
)
|
||||||
|
current += size
|
||||||
|
|
||||||
|
# submit
|
||||||
|
ionum = len(offsets)
|
||||||
|
resv = self.ior_r.submit().wait(min_results=ionum)
|
||||||
|
|
||||||
|
# results
|
||||||
|
hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
|
||||||
|
results = [res.result for res in resv]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@wsynchronized()
|
||||||
|
def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
|
||||||
|
self.check(offsets, tensors)
|
||||||
|
|
||||||
|
# prepare
|
||||||
|
hf3fs_utils.write_shm(tensors, self.shm_w_tensor)
|
||||||
|
current = 0
|
||||||
|
for offset, tensor in zip(offsets, tensors):
|
||||||
|
size = tensor.numel() * tensor.itemsize
|
||||||
|
self.ior_w.prepare(
|
||||||
|
self.iov_w[current : current + size], False, self.file, offset
|
||||||
|
)
|
||||||
|
current += size
|
||||||
|
|
||||||
|
# submit
|
||||||
|
ionum = len(offsets)
|
||||||
|
resv = self.ior_w.submit().wait(min_results=ionum)
|
||||||
|
|
||||||
|
# results
|
||||||
|
results = [res.result for res in resv]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
|
||||||
|
sizes = [t.numel() * t.itemsize for t in tensors]
|
||||||
|
if any(
|
||||||
|
[
|
||||||
|
len(offsets) > self.entries,
|
||||||
|
len(offsets) != len(sizes),
|
||||||
|
all(
|
||||||
|
[
|
||||||
|
offset < 0 or offset + size > self.size
|
||||||
|
for offset, size in zip(offsets, sizes)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
all([size > self.bytes_per_page for size in sizes]),
|
||||||
|
]
|
||||||
|
):
|
||||||
|
self.close()
|
||||||
|
raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}")
|
||||||
|
|
||||||
|
def get_size(self) -> int:
|
||||||
|
return self.size
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
deregister_fd(self.file)
|
||||||
|
os.close(self.file)
|
||||||
|
del self.ior_r
|
||||||
|
del self.ior_w
|
||||||
|
del self.iov_r
|
||||||
|
del self.iov_w
|
||||||
|
self.shm_r.close()
|
||||||
|
self.shm_w.close()
|
||||||
|
self.shm_r.unlink()
|
||||||
|
self.shm_w.unlink()
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
os.fsync(self.file)
|
||||||
35
python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp
Normal file
35
python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// Copy consecutive byte ranges of `shm` into the tensors of `dst`, in order:
// dst[0] receives the first dst[0].nbytes bytes, dst[1] the following ones,
// and so on. Raises (via TORCH_CHECK) if a tensor is not contiguous or if the
// tensors together need more bytes than `shm` holds.
void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
  // Validate before the raw memcpy below: it has no bounds checking and
  // assumes densely packed destination storage.
  TORCH_CHECK(shm.is_contiguous(), "read_shm: shared-memory tensor must be contiguous");
  const size_t shm_bytes = static_cast<size_t>(shm.numel()) * shm.element_size();
  size_t total_bytes = 0;
  for (const auto &t : dst) {
    TORCH_CHECK(t.is_contiguous(), "read_shm: destination tensors must be contiguous");
    total_bytes += static_cast<size_t>(t.numel()) * t.element_size();
  }
  TORCH_CHECK(total_bytes <= shm_bytes, "read_shm: tensors need ", total_bytes,
              " bytes but shared memory holds only ", shm_bytes);

  // The copy itself touches no Python objects, so let other threads run.
  py::gil_scoped_release release;
  char *src_ptr = static_cast<char *>(shm.data_ptr());
  size_t current = 0;
  for (size_t i = 0; i < dst.size(); ++i) {
    auto &t = dst[i];
    size_t t_bytes = t.numel() * t.element_size();
    char *dst_ptr = static_cast<char *>(t.data_ptr());
    std::memcpy(dst_ptr, src_ptr + current, t_bytes);
    current += t_bytes;
  }
}
|
||||||
|
|
||||||
|
// Pack the tensors of `src` back to back into `shm`, in order: src[0] lands
// at byte 0, src[1] right after it, and so on. Raises (via TORCH_CHECK) if a
// tensor is not contiguous or if the tensors together exceed the size of
// `shm`.
void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
  // Validate before the raw memcpy below: it has no bounds checking and
  // assumes densely packed source storage.
  TORCH_CHECK(shm.is_contiguous(), "write_shm: shared-memory tensor must be contiguous");
  const size_t shm_bytes = static_cast<size_t>(shm.numel()) * shm.element_size();
  size_t total_bytes = 0;
  for (const auto &t : src) {
    TORCH_CHECK(t.is_contiguous(), "write_shm: source tensors must be contiguous");
    total_bytes += static_cast<size_t>(t.numel()) * t.element_size();
  }
  TORCH_CHECK(total_bytes <= shm_bytes, "write_shm: tensors need ", total_bytes,
              " bytes but shared memory holds only ", shm_bytes);

  // The copy itself touches no Python objects, so let other threads run.
  py::gil_scoped_release release;
  char *dst_ptr = static_cast<char *>(shm.data_ptr());
  size_t current = 0;
  for (size_t i = 0; i < src.size(); ++i) {
    auto &t = src[i];
    size_t t_bytes = t.numel() * t.element_size();
    char *src_ptr = static_cast<char *>(t.data_ptr());
    std::memcpy(dst_ptr + current, src_ptr, t_bytes);
    current += t_bytes;
  }
}
|
||||||
|
|
||||||
|
// Python bindings: expose the two shared-memory copy helpers as module-level
// functions `read_shm` and `write_shm`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("read_shm", &read_shm, "Read tensors from shared memory");
  m.def("write_shm", &write_shm, "Write tensors to shared memory");
}
|
||||||
278
python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
Normal file
278
python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
import atexit
|
||||||
|
import concurrent.futures
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
from collections import OrderedDict
|
||||||
|
from functools import wraps
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AtomicCounter:
    """Lock-protected counter that cycles through ``0, 1, ..., n - 1``."""

    def __init__(self, n: int):
        """Start the cycle at 0; ``n`` is the cycle length and must be positive."""
        assert n > 0
        self._lock = threading.Lock()
        self._value = 0
        self.n = n

    def next(self) -> int:
        """Return the current value and advance to the next one modulo ``n``."""
        with self._lock:
            ticket, self._value = self._value, (self._value + 1) % self.n
        return ticket
|
||||||
|
|
||||||
|
|
||||||
|
def synchronized():
    """Return a decorator that runs a method while holding ``self.lock``.

    The instance the decorated method is bound to must expose a ``lock``
    attribute that can be used as a context manager.
    """

    def _guard_with_lock(method):
        @wraps(method)
        def _locked_call(self, *args, **kwargs):
            # The lock is held for the whole call and released on return or
            # on exception by the context manager.
            with self.lock:
                outcome = method(self, *args, **kwargs)
            return outcome

        return _locked_call

    return _guard_with_lock
|
||||||
|
|
||||||
|
|
||||||
|
class HiCacheHF3FS(HiCacheStorage):
    """HiCache storage backend that keeps fixed-size pages in one HF3FS file.

    The file is treated as ``num_pages`` slots of ``bytes_per_page`` bytes.
    An in-memory ``OrderedDict`` maps each key to its slot and doubles as the
    least-recently-used order for eviction. I/O is fanned out over
    ``numjobs`` ``Hf3fsClient`` instances through a thread pool, in chunks of
    at most ``entries`` pages.
    """

    default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"

    def __init__(
        self,
        file_path: str,
        file_size: int,
        numjobs: int,
        bytes_per_page: int,
        entries: int,
        dtype: torch.dtype,
    ):
        """Create the clients, the worker pool and the empty page index.

        Args:
            file_path: Backing file shared by all clients of this instance.
            file_size: Size of the backing file in bytes.
            numjobs: Number of clients and of worker threads.
            bytes_per_page: Size of one cached page in bytes.
            entries: Maximum number of pages per client I/O batch.
            dtype: Element type of the tensors returned by ``batch_get``.
        """
        self.file_path = file_path
        self.file_size = file_size
        self.numjobs = numjobs
        self.bytes_per_page = bytes_per_page
        self.entries = entries
        self.dtype = dtype

        # Number of elements of ``dtype`` that make up one page.
        self.numel = self.bytes_per_page // self.dtype.itemsize

        self.num_pages = self.file_size // self.bytes_per_page

        logger.info(
            "HiCacheHF3FS "
            f"file_path = {self.file_path}, "
            f"file_size = {self.file_size/(2**30):.2f} GB, "
            f"numjobs = {self.numjobs}, "
            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
            f"entries = {self.entries}, "
            f"num_pages = {self.num_pages}"
        )

        # Round-robin selector over the clients.
        self.ac = AtomicCounter(self.numjobs)
        self.clients = [
            Hf3fsClient(
                self.file_path, self.file_size, self.bytes_per_page, self.entries
            )
            for _ in range(numjobs)
        ]
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
        )

        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
        # through centralized metadata orchestration.
        self.lock = threading.RLock()
        self.free_pages = list(range(self.num_pages))
        self.key_to_index = OrderedDict()

        # close() is reachable from atexit and from three signal handlers, so
        # it must tolerate being invoked more than once.
        self._closed = False

        atexit.register(self.close)

        # NOTE(review): these replace whatever handlers were installed before
        # and only call close(); they do not terminate the process themselves.
        signal.signal(signal.SIGINT, lambda sig, frame: self.close())
        signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
        signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())

    @staticmethod
    def from_env_config(
        rank: int, bytes_per_page: int, dtype: torch.dtype
    ) -> "HiCacheHF3FS":
        """Build an instance from the JSON file named by ``default_env_var``.

        When the environment variable is unset or empty, built-in defaults
        are used. Otherwise the JSON file must provide ``file_path_prefix``,
        ``file_size``, ``numjobs`` and ``entries``. ``rank`` is embedded in
        the backing file name.

        Raises:
            RuntimeError: If the config file cannot be opened or parsed.
            ValueError: If a required key is missing from the config.
        """
        config_path = os.getenv(HiCacheHF3FS.default_env_var)
        if not config_path:
            return HiCacheHF3FS(
                file_path=f"/data/hicache.{rank}.bin",
                file_size=1 << 40,
                numjobs=16,
                bytes_per_page=bytes_per_page,
                entries=8,
                dtype=dtype,
            )

        try:
            with open(config_path, "r") as f:
                config = json.load(f)
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(
                f"Failed to load config from {config_path}: {str(e)}"
            ) from e

        required_keys = {
            "file_path_prefix",
            "file_size",
            "numjobs",
            "entries",
        }
        missing_keys = required_keys - set(config.keys())
        if missing_keys:
            raise ValueError(f"Missing required keys in config: {missing_keys}")

        return HiCacheHF3FS(
            file_path=f"{config['file_path_prefix']}.{rank}.bin",
            file_size=int(config["file_size"]),
            numjobs=int(config["numjobs"]),
            bytes_per_page=bytes_per_page,
            entries=int(config["entries"]),
            dtype=dtype,
        )

    def get(
        self, key: str, target_location: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        """Return the page stored under ``key``, or ``None`` on miss/failure."""
        # batch_get expects a list of target locations, not a bare tensor.
        target_locations = None if target_location is None else [target_location]
        return self.batch_get([key], target_locations)[0]

    @synchronized()
    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor | None]:
        """Read the pages for ``keys``.

        Returns:
            A list aligned with ``keys``: a freshly allocated 1-D tensor of
            ``dtype`` for every key that is indexed and was read completely,
            ``None`` for unknown keys and for failed reads.
        """
        batch_indices, file_offsets = [], []
        for i, key in enumerate(keys):
            if key not in self.key_to_index:
                continue
            batch_indices.append(i)
            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
            # A lookup counts as a use for LRU purposes.
            self.key_to_index.move_to_end(key)
        # TODO: target_locations
        file_results = [
            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
        ]

        # One task per chunk of at most ``entries`` pages, clients round-robin.
        futures = [
            self.executor.submit(
                self.clients[self.ac.next()].batch_read,
                file_offsets[i : i + self.entries],
                file_results[i : i + self.entries],
            )
            for i in range(0, len(batch_indices), self.entries)
        ]
        read_results = [result for future in futures for result in future.result()]

        results = [None] * len(keys)
        for batch_index, file_result, read_result in zip(
            batch_indices, file_results, read_results
        ):
            if read_result == self.bytes_per_page:
                results[batch_index] = file_result
            else:
                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")

        return results

    def set(self, key: str, value: torch.Tensor) -> bool:
        """Store ``value`` under ``key``; see ``batch_set`` for the result."""
        return self.batch_set([key], [value])

    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        """Store ``values`` under ``keys``.

        Keys that are already indexed are not rewritten. New keys get a page
        from the free list or, once that is exhausted, by evicting the least
        recently used key.

        Returns:
            ``True`` only if every key was already present or was written
            completely.
        """
        indices = self.get_batch_set_indices(keys)
        batch_indices, file_offsets, file_values = [], [], []
        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
            if is_written or index == -1:
                continue
            batch_indices.append(i)
            file_offsets.append(index * self.bytes_per_page)
            file_values.append(value.contiguous())

        # One task per chunk of at most ``entries`` pages, clients round-robin.
        futures = [
            self.executor.submit(
                self.clients[self.ac.next()].batch_write,
                file_offsets[i : i + self.entries],
                file_values[i : i + self.entries],
            )
            for i in range(0, len(batch_indices), self.entries)
        ]
        write_results = [
            result == self.bytes_per_page
            for future in futures
            for result in future.result()
        ]

        results = [index[0] for index in indices]
        # The index and the free list are shared with every other method of
        # this class, all of which hold ``self.lock`` while touching them.
        with self.lock:
            for batch_index, write_result in zip(batch_indices, write_results):
                key = keys[batch_index]
                index = indices[batch_index][1]
                if write_result:
                    self.key_to_index[key] = index
                    self.key_to_index.move_to_end(key)
                else:
                    logger.error(f"HiCacheHF3FS set {key} failed")
                    self.free_pages.append(index)
                results[batch_index] = write_result
        return all(results)

    @synchronized()
    def get_batch_set_indices(self, keys: List[str]) -> list:
        """Decide, for each key, whether a write is needed and to which page.

        Returns:
            A list aligned with ``keys`` of ``(is_written, page_idx)`` tuples.
            ``page_idx`` is ``-1`` when no page could be assigned: the key
            lies beyond the first ``num_pages`` keys of the batch, or nothing
            was left to evict.
        """
        ionum = len(keys)
        # results: tuples of (is_written: bool, page_idx: int)
        # - is_written: True = hit (no I/O), False = write (miss)
        # - page_idx: page storing data
        results = [None] * min(ionum, self.num_pages)
        if ionum > self.num_pages:
            results.extend([(False, -1)] * (ionum - self.num_pages))

        new_keys = []
        hit_keys = set()
        for batch_index, key in enumerate(keys[: self.num_pages]):
            if key in self.key_to_index:
                results[batch_index] = (True, self.key_to_index[key])
                self.key_to_index.move_to_end(key)
                hit_keys.add(key)
            else:
                new_keys.append((batch_index, key))

        for batch_index, _ in new_keys:
            if self.free_pages:
                index = self.free_pages.pop()
            else:
                # Never reclaim a page that this same batch reported as a hit;
                # that would overwrite data the caller was told is stored.
                index = self._evict_lru_page(hit_keys)
            results[batch_index] = (False, index)

        return results

    def _evict_lru_page(self, protected: set) -> int:
        """Evict the least recently used key not in ``protected``.

        Must be called with ``self.lock`` held.

        Returns:
            The page index freed by the eviction, or ``-1`` if the index is
            empty or every remaining key is protected.
        """
        for candidate in self.key_to_index:
            if candidate not in protected:
                # Returning right after the pop keeps the iteration safe.
                return self.key_to_index.pop(candidate)
        return -1

    @synchronized()
    def delete(self, key: str) -> None:
        """Forget ``key`` and return its page to the free list (no-op on miss)."""
        if key not in self.key_to_index:
            return
        index = self.key_to_index.pop(key)
        self.free_pages.append(index)

    @synchronized()
    def exists(self, key: str) -> bool:
        """Return whether ``key`` is currently indexed."""
        return key in self.key_to_index

    @synchronized()
    def clear(self) -> None:
        """Drop every key and mark all pages free; the file is not touched."""
        self.free_pages = list(range(self.num_pages))
        self.key_to_index.clear()

    def close(self) -> None:
        """Stop the worker pool and close every client; errors are logged.

        Only the first call does any work. The pool is drained before the
        clients are closed so that no in-flight task uses a closed client,
        and a failure in one step does not prevent the remaining cleanup.
        """
        if getattr(self, "_closed", False):
            return
        self._closed = True

        try:
            self.executor.shutdown(wait=True)
        except Exception as e:
            logger.error(f"close HiCacheHF3FS: {e}")
        for c in self.clients:
            try:
                c.close()
            except Exception as e:
                logger.error(f"close HiCacheHF3FS: {e}")
        logger.info("close HiCacheHF3FS")
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
import multiprocessing.shared_memory
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.utils.cpp_extension import load
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Compile and load the C++ helper that sits next to this test file.
# This runs at import time of the test module.
root = Path(__file__).parent.resolve()
hf3fs_utils = load(
    name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"], verbose=True
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rw_shm():
    """Round-trip tensors through shared memory via write_shm / read_shm.

    ``page_num`` random tensors are packed into one shared-memory segment and
    unpacked into fresh tensors, which must then equal the inputs. The
    segment is released even when the copy or a comparison fails.
    """
    numel = 8 << 20
    dtype = torch.bfloat16
    page_num = 128
    page_bytes = numel * dtype.itemsize
    shm = multiprocessing.shared_memory.SharedMemory(
        size=page_num * page_bytes, create=True
    )
    tshm = None
    try:
        tshm = torch.frombuffer(shm.buf, dtype=torch.uint8)
        a = [
            torch.randn(numel, dtype=dtype)
            for _ in tqdm(range(page_num), desc="prepare input")
        ]
        b = [
            torch.empty(numel, dtype=dtype)
            for _ in tqdm(range(page_num), desc="prepare output")
        ]
        hf3fs_utils.write_shm(a, tshm)
        hf3fs_utils.read_shm(tshm, b)
        for _a, _b in tqdm(zip(a, b), desc="assert_close"):
            torch.testing.assert_close(_a, _b)
    finally:
        # Drop the tensor view first, then always close and unlink so a
        # failing run does not leave the segment behind.
        del tshm
        try:
            shm.close()
        finally:
            shm.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this file directly: hand it to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||||
@@ -1476,7 +1476,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hicache-storage-backend",
|
"--hicache-storage-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["file"], # todo, mooncake
|
choices=["file", "hf3fs"], # todo, mooncake
|
||||||
default=ServerArgs.hicache_storage_backend,
|
default=ServerArgs.hicache_storage_backend,
|
||||||
help="The storage backend for hierarchical KV cache.",
|
help="The storage backend for hierarchical KV cache.",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user