3fs zerocopy (#9109)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
pansicheng
2025-08-22 17:56:38 +08:00
committed by GitHub
parent cebf45994b
commit 70cf4abccc
7 changed files with 310 additions and 29 deletions

View File

@@ -8,6 +8,9 @@ from typing import List
import torch
from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
@@ -67,12 +70,15 @@ def test():
k = f"key_{i}"
v = torch.randn((numel,)).to(dtype=dtype)
ok = hicache_hf3fs.set(k, v)
assert ok, f"Failed to insert {k}"
if i < (file_size // bytes_per_page):
assert ok, f"Failed to insert {k}"
else:
assert not ok
tensors[k] = v
assert hicache_hf3fs.get("key_0") is None
assert hicache_hf3fs.get("key_1") is None
assert hicache_hf3fs.get("key_8") is None
assert hicache_hf3fs.get("key_9") is None
start = num_pages - hicache_hf3fs.num_pages
start = 0
for i in range(start, start + hicache_hf3fs.num_pages):
k = f"key_{i}"
assert hicache_hf3fs.exists(k)
@@ -83,13 +89,16 @@ def test():
assert not hicache_hf3fs.exists("not_exists")
hicache_hf3fs.delete("key_9")
hicache_hf3fs.delete("key_7")
v2 = torch.randn((numel,)).to(dtype=dtype)
assert hicache_hf3fs.set("key_new", v2)
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
hicache_hf3fs.clear()
assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages
assert (
len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
== hicache_hf3fs.metadata_client.rank_metadata.num_pages
)
# batch
num_pages = 10
@@ -134,12 +143,14 @@ def bench():
entries = 8
dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS(
rank=0,
file_path=file_path,
file_size=file_size,
numjobs=numjobs,
bytes_per_page=bytes_per_page,
entries=entries,
dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(),
)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
@@ -167,7 +178,10 @@ def bench():
r_bw = []
r_size = num_page * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
keys = random.sample(
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
num_page,
)
tik = time.perf_counter()
results = hicache_hf3fs.batch_get(keys)
tok = time.perf_counter()
@@ -195,12 +209,14 @@ def allclose():
entries = 8
dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS(
rank=0,
file_path=file_path,
file_size=file_size,
numjobs=numjobs,
bytes_per_page=bytes_per_page,
entries=entries,
dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(),
)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
@@ -218,7 +234,10 @@ def allclose():
read_keys, read_results = [], []
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
keys = random.sample(
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
num_page,
)
results = hicache_hf3fs.batch_get(keys)
read_keys.extend(keys)
read_results.extend(results)