3fs zerocopy (#9109)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -1,6 +1,16 @@
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||
python3 benchmark/hf3fs/bench_client.py
|
||||
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
|
||||
python3 benchmark/hf3fs/bench_storage.py
|
||||
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json
|
||||
echo '{"file_path_prefix": "/data/hf3fs-test-0", "file_size": 1099511627776, "numjobs": 16, "entries": 8}' > \
|
||||
${SGLANG_HICACHE_HF3FS_CONFIG_PATH}
|
||||
python3 benchmark/hf3fs/bench_zerocopy.py
|
||||
|
||||
####################################################################################################
|
||||
|
||||
rm -rf nohup.out && \
|
||||
|
||||
@@ -8,6 +8,9 @@ from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||
Hf3fsLocalMetadataClient,
|
||||
)
|
||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||
|
||||
|
||||
@@ -67,12 +70,15 @@ def test():
|
||||
k = f"key_{i}"
|
||||
v = torch.randn((numel,)).to(dtype=dtype)
|
||||
ok = hicache_hf3fs.set(k, v)
|
||||
assert ok, f"Failed to insert {k}"
|
||||
if i < (file_size // bytes_per_page):
|
||||
assert ok, f"Failed to insert {k}"
|
||||
else:
|
||||
assert not ok
|
||||
tensors[k] = v
|
||||
assert hicache_hf3fs.get("key_0") is None
|
||||
assert hicache_hf3fs.get("key_1") is None
|
||||
assert hicache_hf3fs.get("key_8") is None
|
||||
assert hicache_hf3fs.get("key_9") is None
|
||||
|
||||
start = num_pages - hicache_hf3fs.num_pages
|
||||
start = 0
|
||||
for i in range(start, start + hicache_hf3fs.num_pages):
|
||||
k = f"key_{i}"
|
||||
assert hicache_hf3fs.exists(k)
|
||||
@@ -83,13 +89,16 @@ def test():
|
||||
|
||||
assert not hicache_hf3fs.exists("not_exists")
|
||||
|
||||
hicache_hf3fs.delete("key_9")
|
||||
hicache_hf3fs.delete("key_7")
|
||||
v2 = torch.randn((numel,)).to(dtype=dtype)
|
||||
assert hicache_hf3fs.set("key_new", v2)
|
||||
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
|
||||
|
||||
hicache_hf3fs.clear()
|
||||
assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages
|
||||
assert (
|
||||
len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
|
||||
== hicache_hf3fs.metadata_client.rank_metadata.num_pages
|
||||
)
|
||||
|
||||
# batch
|
||||
num_pages = 10
|
||||
@@ -134,12 +143,14 @@ def bench():
|
||||
entries = 8
|
||||
dtype = store_dtype
|
||||
hicache_hf3fs = HiCacheHF3FS(
|
||||
rank=0,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
numjobs=numjobs,
|
||||
bytes_per_page=bytes_per_page,
|
||||
entries=entries,
|
||||
dtype=dtype,
|
||||
metadata_client=Hf3fsLocalMetadataClient(),
|
||||
)
|
||||
|
||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||
@@ -167,7 +178,10 @@ def bench():
|
||||
r_bw = []
|
||||
r_size = num_page * bytes_per_page / (1 << 30)
|
||||
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
|
||||
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
|
||||
keys = random.sample(
|
||||
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
|
||||
num_page,
|
||||
)
|
||||
tik = time.perf_counter()
|
||||
results = hicache_hf3fs.batch_get(keys)
|
||||
tok = time.perf_counter()
|
||||
@@ -195,12 +209,14 @@ def allclose():
|
||||
entries = 8
|
||||
dtype = store_dtype
|
||||
hicache_hf3fs = HiCacheHF3FS(
|
||||
rank=0,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
numjobs=numjobs,
|
||||
bytes_per_page=bytes_per_page,
|
||||
entries=entries,
|
||||
dtype=dtype,
|
||||
metadata_client=Hf3fsLocalMetadataClient(),
|
||||
)
|
||||
|
||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||
@@ -218,7 +234,10 @@ def allclose():
|
||||
|
||||
read_keys, read_results = [], []
|
||||
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
|
||||
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
|
||||
keys = random.sample(
|
||||
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
|
||||
num_page,
|
||||
)
|
||||
results = hicache_hf3fs.batch_get(keys)
|
||||
read_keys.extend(keys)
|
||||
read_results.extend(results)
|
||||
|
||||
140
benchmark/hf3fs/bench_zerocopy.py
Normal file
140
benchmark/hf3fs/bench_zerocopy.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_world_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.managers.cache_controller import (
|
||||
HiCacheController,
|
||||
PrefetchOperation,
|
||||
StorageOperation,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
|
||||
|
||||
# Benchmark script: times storage backup and prefetch of KV-cache pages
# through HiCacheController, using the zero-copy path for the "page_first"
# host layout and the generic path for "layer_first".

# Single-process distributed setup; the controller below is handed the world
# group's CPU process group.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:23456",
    local_rank=0,
    backend="gloo",
)

initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)

group = get_world_group().cpu_group

# ---- Benchmark configuration ------------------------------------------------
max_total_num_tokens = 524288
page_size = 64
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"
hicache_ratio = 2
hicache_size = 0
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256

# Each operation covers `op_size` consecutive token slots; `op_num` operations
# are issued back to back in each timed phase.
op_size = 1024
op_num = 16

# Fail fast on a layout this script has no code path for. Without this check
# both timed sections would skip all I/O and still print an elapsed time.
if hicache_mem_layout not in ("page_first", "layer_first"):
    raise ValueError(
        f"Unsupported hicache_mem_layout {hicache_mem_layout!r}; "
        "expected 'page_first' or 'layer_first'"
    )

# ---- Device pool, allocator, host pool and controller -----------------------
token_to_kv_pool = MHATokenToKVPool(
    max_total_num_tokens,
    page_size=page_size,
    dtype=kv_cache_dtype,
    head_num=head_num,
    head_dim=head_dim,
    layer_num=layer_num,
    device=device,
    enable_memory_saver=True,
)

token_to_kv_pool_allocator = TokenToKVPoolAllocator(
    max_total_num_tokens,
    dtype=kv_cache_dtype,
    device=device,
    kvcache=token_to_kv_pool,
    need_sort=False,
)

kv_cache = token_to_kv_pool_allocator.get_kvcache()
token_to_kv_pool_host = MHATokenToKVPoolHost(
    kv_cache,
    hicache_ratio,
    hicache_size,
    page_size,
    hicache_mem_layout,
)

load_cache_event = threading.Event()
cache_controller = HiCacheController(
    token_to_kv_pool_allocator,
    token_to_kv_pool_host,
    page_size,
    group,
    load_cache_event=load_cache_event,
    write_policy=hicache_write_policy,
    io_backend=hicache_io_backend,
    storage_backend=hicache_storage_backend,
    prefetch_threshold=prefetch_threshold,
)

# Resolve the layout-specific entry points once so the timed loops contain
# nothing but the calls being measured.
if hicache_mem_layout == "page_first":
    _backup_pages = cache_controller.zerocopy_page_backup
    _transfer_pages = cache_controller.zerocopy_page_transfer
else:  # "layer_first" (validated above)
    _backup_pages = cache_controller.generic_page_backup
    _transfer_pages = cache_controller.generic_page_transfer


def _page_hashes(first_token):
    """Return the string keys for one operation starting at ``first_token``.

    One key is produced per ``page_size`` stride across the ``op_size`` token
    slots; each key is the decimal string of the stride's first slot index.
    """
    return [f"{j}" for j in range(first_token, first_token + op_size, page_size)]


# ---- Phase 1: backup (write to storage) -------------------------------------
operations = [
    StorageOperation(
        torch.arange(i, i + op_size),
        list(range(i, i + op_size)),
        hash_value=_page_hashes(i),
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

tik = time.monotonic()
for operation in operations:
    _backup_pages(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")

# ---- Phase 2: prefetch (read back from storage) -----------------------------
operations = [
    PrefetchOperation(
        f"{i}",
        torch.arange(i, i + op_size),
        list(range(i, i + op_size)),
        f"{i}",
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

# The last constructor argument above is read back here as `last_hash` and is
# the operation's starting slot index as a string; derive the same per-page
# keys that were written in phase 1.
for operation in operations:
    operation.hash_value = _page_hashes(int(operation.last_hash))

tik = time.monotonic()
for operation in operations:
    _transfer_pages(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")
|
||||
Reference in New Issue
Block a user