diff --git a/benchmark/hf3fs/bench.sh b/benchmark/hf3fs/bench.sh new file mode 100644 index 000000000..bb1bbcd32 --- /dev/null +++ b/benchmark/hf3fs/bench.sh @@ -0,0 +1,49 @@ +SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \ +python3 benchmark/hf3fs/bench_storage.py + +#################################################################################################### + +rm -rf nohup.out && \ +nohup python3 -m sglang.launch_server \ + --model-path /code/models/Qwen3-32B/ \ + --host 0.0.0.0 --port 33301 \ + --page-size 64 \ + --enable-hierarchical-cache \ + --hicache-ratio 2 --hicache-size 0 \ + --hicache-write-policy write_through \ + --hicache-storage-backend hf3fs & + +rm -rf bench_multiturn.out && \ +nohup python3 benchmark/hicache/bench_multiturn.py \ + --model-path /code/models/Qwen3-32B \ + --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \ + --port 33301 \ + --request-length 2048 --num-clients 512 --num-rounds 3 --max-parallel 8 \ + > bench_multiturn.out & + +#################################################################################################### + +rm -rf nohup.out && \ +nohup python3 -m sglang.launch_server \ + --model-path /code/models/DeepSeek-R1/ \ + --tp 16 --nnodes 2 --node-rank 0 \ + --dist-init-addr 10.74.249.153:5000 \ + --host 0.0.0.0 --port 33301 \ + --page-size 64 \ + --enable-hierarchical-cache \ + --hicache-ratio 2 --hicache-size 60 \ + --hicache-write-policy write_through \ + --hicache-storage-backend hf3fs & + +rm -rf bench_multiturn.out && \ +nohup python3 benchmark/hicache/bench_multiturn.py \ + --model-path /code/models/Qwen3-32B \ + --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \ + --port 33301 \ + --request-length 2048 --num-clients 1024 --num-rounds 3 --max-parallel 8 \ + > bench_multiturn.out & + +#################################################################################################### + +ps aux | grep "sglang.launch_server" | grep -v grep | awk '{print $2}' | xargs kill -9 +ps aux | grep "bench_multiturn.py" | grep -v grep | awk '{print $2}' | xargs kill -9 diff --git a/benchmark/hf3fs/bench_client.py b/benchmark/hf3fs/bench_client.py new file mode 100644 index 000000000..33c502575 --- /dev/null +++ b/benchmark/hf3fs/bench_client.py @@ -0,0 +1,162 @@ +import concurrent.futures +import logging +import random +import time +from typing import List + +import torch +from tqdm import tqdm + +from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient + + +def print_stats(x: List[int]): + x = sorted(x) + lenx = len(x) + print( + f"mean = {sum(x)/len(x):.2f}, " + f"min = {min(x):.2f}, " + f"p25 = {x[int(lenx*0.25)]:.2f}, " + f"p50 = {x[int(lenx*0.5)]:.2f}, " + f"p75 = {x[int(lenx*0.75)]:.2f}, " + f"max = {max(x):.2f}" + ) + + +def test(): + # /path/to/hf3fs + file_path = "/data/bench.bin" + file_size = 1 << 40 + bytes_per_page = 16 << 20 + entries = 32 + file_ops = Hf3fsClient(file_path, file_size, bytes_per_page, entries) + + print("test batch_read / batch_write") + num_pages = 128 + dtype = torch.bfloat16 + numel = bytes_per_page // dtype.itemsize + offsets = list(range(file_size // bytes_per_page)) + random.shuffle(offsets) + offsets = offsets[:num_pages] + offsets = [i * bytes_per_page for i in offsets] + tensor_writes = [ + torch.randn(numel, dtype=dtype) + for _ in tqdm(range(num_pages), desc="prepare tensor") + ] + for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"): + results = file_ops.batch_write( + offsets[i : i + 
file_ops.entries], tensor_writes[i : i + file_ops.entries] + ) + assert all([result == numel * dtype.itemsize for result in results]) + tensor_reads = [ + torch.empty(numel, dtype=dtype) + for _ in tqdm(range(num_pages), desc="prepare tensor") + ] + for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"): + results = file_ops.batch_read( + offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries] + ) + assert all([result == numel * dtype.itemsize for result in results]) + assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)]) + + file_ops.close() + print("test done") + + +def bench(): + file_path = "/data/bench.bin" + file_size = 1 << 40 + bytes_per_page = 16 << 20 + entries = 8 + numjobs = 16 + + dtype = torch.bfloat16 + numel = bytes_per_page // dtype.itemsize + + file_ops = [ + Hf3fsClient(file_path, file_size, bytes_per_page, entries) + for _ in range(numjobs) + ] + + num_page = entries + + offsets = list(range(file_size // bytes_per_page)) + tensors_write = [torch.randn(numel, dtype=dtype)] * num_page + tensors_read = [torch.empty(numel, dtype=dtype)] * num_page + random.shuffle(offsets) + + warmup = 50 + iteration = 100 + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs) + + w_bw = [] + w_size = num_page * numjobs * bytes_per_page / (1 << 30) + for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"): + _offsets = [ + [ + offset * bytes_per_page + for offset in offsets[ + (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page + ] + ] + for j in range(numjobs) + ] + tik = time.perf_counter() + futures = [ + executor.submit(file_ops[j].batch_write, offset, tensors_write) + for j, offset in enumerate(_offsets) + ] + results = [future.result() for future in futures] + tok = time.perf_counter() + if i < warmup: + continue + w_bw.append(w_size / (tok - tik)) + results = [ + _result == bytes_per_page for result in results for _result in result + ] + assert all(results) + print_stats(w_bw) + + r_bw = [] + r_size = w_size + for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"): + _offsets = [ + [ + offset * bytes_per_page + for offset in offsets[ + (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page + ] + ] + for j in range(numjobs) + ] + tik = time.perf_counter() + futures = [ + executor.submit(file_ops[j].batch_read, offset, tensors_read) + for j, offset in enumerate(_offsets) + ] + results = [future.result() for future in futures] + tok = time.perf_counter() + if i < warmup: + continue + r_bw.append(r_size / (tok - tik)) + results = [ + _result == bytes_per_page for result in results for _result in result + ] + assert all(results) + print_stats(r_bw) + + executor.shutdown(wait=True) + for _file_ops in file_ops: + _file_ops.close() + print("bench done") + + +def main(): + logging.basicConfig(level=logging.INFO) + test() + bench() + + +if __name__ == "__main__": + main() diff --git a/benchmark/hf3fs/bench_storage.py b/benchmark/hf3fs/bench_storage.py new file mode 100644 index 000000000..4e96c8ec9 --- /dev/null +++ b/benchmark/hf3fs/bench_storage.py @@ -0,0 +1,241 @@ +import json +import logging +import os +import random +import time +from typing import List + +import torch +from tqdm import tqdm + +from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS + + +def print_stats(x: List[int]): + x = sorted(x) + lenx = len(x) + print( + f"mean = {sum(x)/len(x):.2f}, " + f"min = {min(x):.2f}, " + f"p25 = {x[int(lenx*0.25)]:.2f}, " + f"p50 = 
{x[int(lenx*0.5)]:.2f}, " + f"p75 = {x[int(lenx*0.75)]:.2f}, " + f"max = {max(x):.2f}" + ) + + +def test(): + # Qwen3-32B + layer_num = 64 + head_num, head_dim = 8, 128 + kv_lora_rank, qk_rope_head_dim = 0, 0 + store_dtype = torch.bfloat16 + tokens_per_page = 64 + + file_path_prefix = "/data/test" + file_size = 128 << 20 + numjobs = 16 + bytes_per_page = 16 << 20 + entries = 2 + dtype = store_dtype + + config_path = os.getenv(HiCacheHF3FS.default_env_var) + assert config_path + try: + with open(config_path, "w") as f: + json.dump( + { + "file_path_prefix": file_path_prefix, + "file_size": file_size, + "numjobs": numjobs, + "entries": entries, + }, + f, + ) + except Exception as e: + raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}") + + rank = 0 + hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype) + + numel = 2 * tokens_per_page * layer_num * head_num * head_dim + assert numel * dtype.itemsize == bytes_per_page + + num_pages = 10 + tensors = {} + for i in range(num_pages): + k = f"key_{i}" + v = torch.randn((numel,)).to(dtype=dtype) + ok = hicache_hf3fs.set(k, v) + assert ok, f"Failed to insert {k}" + tensors[k] = v + assert hicache_hf3fs.get("key_0") is None + assert hicache_hf3fs.get("key_1") is None + + start = num_pages - hicache_hf3fs.num_pages + for i in range(start, start + hicache_hf3fs.num_pages): + k = f"key_{i}" + assert hicache_hf3fs.exists(k) + out = hicache_hf3fs.get(k) + assert out is not None + v = tensors[k] + assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}" + + assert not hicache_hf3fs.exists("not_exists") + + hicache_hf3fs.delete("key_9") + v2 = torch.randn((numel,)).to(dtype=dtype) + assert hicache_hf3fs.set("key_new", v2) + assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3) + + hicache_hf3fs.clear() + assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages + + # batch + num_pages = 10 + tensors = {} + keys = [] + values = [] + for i in range(num_pages): + k = f"key_{i}" + keys.append(k) + v = torch.randn((numel,)).to(dtype=dtype) + values.append(v) + + ok = hicache_hf3fs.batch_set(keys, values) + assert not ok + assert hicache_hf3fs.get("key_8") is None + assert hicache_hf3fs.get("key_9") is None + + results = hicache_hf3fs.batch_get(keys[: hicache_hf3fs.num_pages]) + for result, key, value in zip( + results, keys[: hicache_hf3fs.num_pages], values[: hicache_hf3fs.num_pages] + ): + assert torch.allclose(value, result, atol=1e-3), f"Tensor mismatch for {key}" + + hicache_hf3fs.close() + os.remove(hicache_hf3fs.file_path) + + print("All test cases passed.") + + +def bench(): + # Qwen3-32B + layer_num = 64 + head_num, head_dim = 8, 128 + kv_lora_rank, qk_rope_head_dim = 0, 0 + store_dtype = torch.bfloat16 + tokens_per_page = 64 + + file_path = "/data/test.bin" + file_size = 1 << 40 + numjobs = 16 + bytes_per_page = 16 << 20 + entries = 8 + dtype = store_dtype + hicache_hf3fs = HiCacheHF3FS( + file_path=file_path, + file_size=file_size, + numjobs=numjobs, + bytes_per_page=bytes_per_page, + entries=entries, + dtype=dtype, + ) + + numel = 2 * tokens_per_page * layer_num * head_num * head_dim + assert numel * dtype.itemsize == bytes_per_page + + num_page = 128 + values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))] + + warmup = 50 + iteration = 100 + + w_bw = [] + w_size = num_page * bytes_per_page / (1 << 30) + for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"): + keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)] + tik = 
time.perf_counter() + ok = hicache_hf3fs.batch_set(keys, values) + tok = time.perf_counter() + if i < warmup: + continue + w_bw.append(w_size / (tok - tik)) + assert ok + print_stats(w_bw) + + r_bw = [] + r_size = num_page * bytes_per_page / (1 << 30) + for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"): + keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page) + tik = time.perf_counter() + results = hicache_hf3fs.batch_get(keys) + tok = time.perf_counter() + if i < warmup: + continue + r_bw.append(r_size / (tok - tik)) + assert all([r is not None for r in results]) + print_stats(r_bw) + + hicache_hf3fs.close() + + +def allclose(): + # Qwen3-32B + layer_num = 64 + head_num, head_dim = 8, 128 + kv_lora_rank, qk_rope_head_dim = 0, 0 + store_dtype = torch.bfloat16 + tokens_per_page = 64 + + file_path = "/data/test.bin" + file_size = 1 << 40 + numjobs = 16 + bytes_per_page = 16 << 20 + entries = 8 + dtype = store_dtype + hicache_hf3fs = HiCacheHF3FS( + file_path=file_path, + file_size=file_size, + numjobs=numjobs, + bytes_per_page=bytes_per_page, + entries=entries, + dtype=dtype, + ) + + numel = 2 * tokens_per_page * layer_num * head_num * head_dim + assert numel * dtype.itemsize == bytes_per_page + + num_page = 128 + values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))] + + iteration = 100 + + for i in tqdm(range(iteration), desc="Benchmarking write (GB/s)"): + keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)] + ok = hicache_hf3fs.batch_set(keys, values) + assert ok + + read_keys, read_results = [], [] + for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"): + keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page) + results = hicache_hf3fs.batch_get(keys) + read_keys.extend(keys) + read_results.extend(results) + assert all([r is not None for r in results]) + + for key, result in tqdm(zip(read_keys, read_results)): + assert torch.allclose(values[int(key) % num_page], result, atol=1e-3) + + hicache_hf3fs.close() + + +def main(): + logging.basicConfig(level=logging.INFO) + test() + bench() + allclose() + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index fb7ad794f..629e77748 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool_host import HostKVCache from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str +from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS logger = logging.getLogger(__name__) @@ -250,17 +251,33 @@ class HiCacheController: self.tp_world_size = torch.distributed.get_world_size(group=tp_group) if self.tp_world_size > 1: group_ranks = torch.distributed.get_process_group_ranks(tp_group) - self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo") + self.prefetch_tp_group = torch.distributed.new_group( + group_ranks, backend="gloo" + ) + self.backup_tp_group = torch.distributed.new_group( + group_ranks, backend="gloo" + ) if storage_backend == "file": self.storage_backend = HiCacheFile() - self.enable_storage = True - # todo: threshold policy for prefetching - self.prefetch_threshold = max(prefetch_threshold, self.page_size) + elif storage_backend == "hf3fs": + from sglang.srt.distributed import get_tensor_model_parallel_rank + + rank = get_tensor_model_parallel_rank() + bytes_per_page = ( + 
mem_pool_host.get_size_per_token() * mem_pool_host.page_size + ) + dtype = mem_pool_host.dtype + self.storage_backend = HiCacheHF3FS.from_env_config( + rank, bytes_per_page, dtype + ) else: raise NotImplementedError( f"Unsupported storage backend: {storage_backend}" ) + self.enable_storage = True + # todo: threshold policy for prefetching + self.prefetch_threshold = max(prefetch_threshold, self.page_size) self.load_cache_event = load_cache_event self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) @@ -522,8 +539,8 @@ class HiCacheController: while not self.stop_event.is_set(): try: operation = self.prefetch_buffer.get(block=True, timeout=1) - for h in operation.hash_value: - page_data = self.storage_backend.get(h) + page_datas = self.storage_backend.batch_get(operation.hash_value) + for h, page_data in zip(operation.hash_value, page_datas): if page_data is None: logger.warning( f"Prefetch operation {operation.request_id} failed to retrieve page {h}." @@ -531,7 +548,9 @@ class HiCacheController: break if operation.increment(self.page_size): self.mem_pool_host.set_from_flat_data_page( - operation.host_indices[operation.completed_tokens], + operation.host_indices[ + operation.completed_tokens - self.page_size + ], page_data, ) else: @@ -583,7 +602,7 @@ class HiCacheController: torch.distributed.all_reduce( storage_hit_count_tensor, op=torch.distributed.ReduceOp.MIN, - group=self.tp_group, + group=self.prefetch_tp_group, ) storage_hit_count = storage_hit_count_tensor.item() @@ -635,21 +654,23 @@ class HiCacheController: last_hash = operation.last_hash tokens_to_backup = operation.token_ids + last_hashes, data_pages = [], [] for i in range(0, len(tokens_to_backup), self.page_size): last_hash = get_hash_str( tokens_to_backup[i : i + self.page_size], last_hash ) - success = self.storage_backend.set( - last_hash, - self.mem_pool_host.get_flat_data_page( - operation.host_indices[i] - ), + data_page = self.mem_pool_host.get_flat_data_page( + operation.host_indices[i] ) - if not success: - logger.warning(f"Failed to write page {last_hash} to storage.") - break - operation.completed_tokens += self.page_size - operation.hash_value.append(last_hash) + last_hashes.append(last_hash) + data_pages.append(data_page) + + success = self.storage_backend.batch_set(last_hashes, data_pages) + if not success: + logger.warning(f"Failed to write page {last_hashes} to storage.") + else: + operation.completed_tokens += len(tokens_to_backup) + operation.hash_value.extend(last_hashes) min_completed_tokens = operation.completed_tokens if self.tp_world_size > 1: @@ -659,7 +680,7 @@ class HiCacheController: torch.distributed.all_reduce( completed_tokens_tensor, op=torch.distributed.ReduceOp.MIN, - group=self.tp_group, + group=self.backup_tp_group, ) min_completed_tokens = completed_tokens_tensor.item() diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index e6acbe9cc..f939fff4b 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -79,7 +79,9 @@ class HiRadixCache(RadixCache): self.write_through_threshold = ( 1 if hicache_write_policy == "write_through" else 3 ) - self.write_through_threshold_storage = 3 + self.write_through_threshold_storage = ( + 1 if hicache_write_policy == "write_through" else 3 + ) self.load_back_threshold = 10 super().__init__( req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False @@ -388,10 +390,14 @@ class HiRadixCache(RadixCache): 
self.cache_controller.ack_backup_queue.get() ) host_node = self.ongoing_backup[ack_id] - if completed_tokens < len(host_node.key): + if completed_tokens == 0: + host_node.hash_value = None + elif completed_tokens < len(host_node.key): # backup is only partially successful, split the node new_node = self._split_node(host_node.key, host_node, completed_tokens) new_node.hash_value = hash_value + else: + host_node.hash_value = hash_value host_node.release_host() del self.ongoing_backup[ack_id] @@ -431,6 +437,8 @@ class HiRadixCache(RadixCache): written_indices, hash_value[:min_completed_tokens], ) + if len(written_indices): + self.cache_controller.mem_pool_host.update_prefetch(written_indices) self.cache_controller.mem_pool_host.free(host_indices[:matched_length]) self.cache_controller.mem_pool_host.free( diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 0116e7141..c2fb4fa46 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -25,7 +25,6 @@ def synchronized(debug_only=False): @wraps(func) def wrapper(self, *args, **kwargs): if (not debug_only) or self.debug: - return func(self, *args, **kwargs) with self.lock: return func(self, *args, **kwargs) else: @@ -181,6 +180,15 @@ class HostKVCache(abc.ABC): ) self.mem_state[indices] = MemoryStateInt.BACKUP + @synchronized(debug_only=True) + def update_prefetch(self, indices: torch.Tensor): + if not self.is_reserved(indices): + raise ValueError( + f"The host memory slots should be in RESERVED state before turning into BACKUP. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.BACKUP + @synchronized(debug_only=True) def update_synced(self, indices: torch.Tensor): self.mem_state[indices] = MemoryStateInt.SYNCED diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/README.md b/python/sglang/srt/mem_cache/storage/hf3fs/README.md new file mode 100644 index 000000000..5fa1fa4c2 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/hf3fs/README.md @@ -0,0 +1,65 @@ +# HiCacheHF3FS Setup + +## Build & Package +### Source Code +https://github.com/deepseek-ai/3FS/blob/main/README.md#check-out-source-code +```sh +git clone https://github.com/deepseek-ai/3fs + +cd 3fs +git submodule update --init --recursive +./patches/apply.sh +``` + +### Build Dev Container +https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile +```sh +cd 3fs/dockerfile +docker build -t hf3fs:dev -f dev.dockerfile . 
+``` + +### Generate Python Wheel +```sh +docker run -it hf3fs:dev bash + +# Inside the development container +git clone https://github.com/deepseek-ai/3fs + +cd 3fs +git submodule update --init --recursive +./patches/apply.sh + +apt-get update \ +&& apt-get install -y --no-install-recommends \ +python3 python3-pip \ +&& apt-get clean \ +&& rm -rf /var/lib/apt/lists/* + +# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl +python3 setup.py bdist_wheel +``` + +## Installation +```sh +# Install Dependencies +# https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile +apt update && apt install -y \ + libaio-dev \ + libboost-all-dev \ + libdouble-conversion-dev \ + libdwarf-dev \ + libgflags-dev \ + libgmock-dev \ + libgoogle-glog-dev \ + libgoogle-perftools-dev \ + libgtest-dev \ + liblz4-dev \ + liblzma-dev \ + libssl-dev \ + libunwind-dev \ + libuv1-dev + +# Install Python Package +pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages +``` diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py new file mode 100644 index 000000000..09832b8e2 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py @@ -0,0 +1,177 @@ +import logging +import multiprocessing +import os +import threading +from functools import wraps +from pathlib import Path +from typing import List + +import torch +from torch.utils.cpp_extension import load + +root = Path(__file__).parent.resolve() +hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"]) + +logger = logging.getLogger(__name__) + +try: + from hf3fs_fuse.io import ( + deregister_fd, + extract_mount_point, + make_ioring, + make_iovec, + register_fd, + ) +except ImportError as e: + logger.warning(f"hf3fs_fuse.io is not available: {e}") + + +def rsynchronized(): + def _decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.rlock: + return func(self, *args, **kwargs) + + return wrapper + + return _decorator + + +def wsynchronized(): + def _decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.wlock: + return func(self, *args, **kwargs) + + return wrapper + + return _decorator + + +class Hf3fsClient: + def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): + self.path = path + self.size = size + self.bytes_per_page = bytes_per_page + self.entries = entries + + self.file = os.open(self.path, os.O_RDWR | os.O_CREAT) + os.ftruncate(self.file, size) + register_fd(self.file) + + self.hf3fs_mount_point = extract_mount_point(path) + self.bs = self.bytes_per_page + self.shm_r = multiprocessing.shared_memory.SharedMemory( + size=self.bs * self.entries, create=True + ) + self.shm_w = multiprocessing.shared_memory.SharedMemory( + size=self.bs * self.entries, create=True + ) + + self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8) + self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8) + + self.numa = -1 + self.ior_r = make_ioring( + self.hf3fs_mount_point, + self.entries, + for_read=True, + timeout=1, + numa=self.numa, + ) + self.ior_w = make_ioring( + self.hf3fs_mount_point, + self.entries, + for_read=False, + timeout=1, + numa=self.numa, + ) + self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point) + self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point) + + self.rlock = threading.RLock() + 
self.wlock = threading.RLock() + + @rsynchronized() + def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]: + self.check(offsets, tensors) + + # prepare + current = 0 + for offset, tensor in zip(offsets, tensors): + size = tensor.numel() * tensor.itemsize + self.ior_r.prepare( + self.iov_r[current : current + size], True, self.file, offset + ) + current += size + + # submit + ionum = len(offsets) + resv = self.ior_r.submit().wait(min_results=ionum) + + # results + hf3fs_utils.read_shm(self.shm_r_tensor, tensors) + results = [res.result for res in resv] + + return results + + @wsynchronized() + def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]: + self.check(offsets, tensors) + + # prepare + hf3fs_utils.write_shm(tensors, self.shm_w_tensor) + current = 0 + for offset, tensor in zip(offsets, tensors): + size = tensor.numel() * tensor.itemsize + self.ior_w.prepare( + self.iov_w[current : current + size], False, self.file, offset + ) + current += size + + # submit + ionum = len(offsets) + resv = self.ior_w.submit().wait(min_results=ionum) + + # results + results = [res.result for res in resv] + + return results + + def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None: + sizes = [t.numel() * t.itemsize for t in tensors] + if any( + [ + len(offsets) > self.entries, + len(offsets) != len(sizes), + all( + [ + offset < 0 or offset + size > self.size + for offset, size in zip(offsets, sizes) + ] + ), + all([size > self.bytes_per_page for size in sizes]), + ] + ): + self.close() + raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}") + + def get_size(self) -> int: + return self.size + + def close(self) -> None: + deregister_fd(self.file) + os.close(self.file) + del self.ior_r + del self.ior_w + del self.iov_r + del self.iov_w + self.shm_r.close() + self.shm_w.close() + self.shm_r.unlink() + self.shm_w.unlink() + + def flush(self) -> None: + os.fsync(self.file) diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp new file mode 100644 index 000000000..3a4b7dcc0 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp @@ -0,0 +1,35 @@ +#include + +#include +#include + +void read_shm(const torch::Tensor &shm, std::vector dst) { + py::gil_scoped_release release; + char *src_ptr = static_cast(shm.data_ptr()); + size_t current = 0; + for (size_t i = 0; i < dst.size(); ++i) { + auto &t = dst[i]; + size_t t_bytes = t.numel() * t.element_size(); + char *dst_ptr = static_cast(t.data_ptr()); + std::memcpy(dst_ptr, src_ptr + current, t_bytes); + current += t_bytes; + } +} + +void write_shm(const std::vector src, torch::Tensor &shm) { + py::gil_scoped_release release; + char *dst_ptr = static_cast(shm.data_ptr()); + size_t current = 0; + for (size_t i = 0; i < src.size(); ++i) { + auto &t = src[i]; + size_t t_bytes = t.numel() * t.element_size(); + char *src_ptr = static_cast(t.data_ptr()); + std::memcpy(dst_ptr + current, src_ptr, t_bytes); + current += t_bytes; + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("read_shm", &read_shm, "Read tensors from shared memory"); + m.def("write_shm", &write_shm, "Write tensors to shared memory"); +} diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py new file mode 100644 index 000000000..0cc2b0a26 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -0,0 +1,278 @@ 
+import atexit +import concurrent.futures +import json +import logging +import os +import signal +import threading +from collections import OrderedDict +from functools import wraps +from typing import List, Optional + +import torch + +from sglang.srt.mem_cache.hicache_storage import HiCacheStorage +from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient + +logger = logging.getLogger(__name__) + + +class AtomicCounter: + def __init__(self, n: int): + assert n > 0 + self.n = n + self._value = 0 + self._lock = threading.Lock() + + def next(self) -> int: + with self._lock: + current = self._value + self._value = (current + 1) % self.n + return current + + +def synchronized(): + def _decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + + return wrapper + + return _decorator + + +class HiCacheHF3FS(HiCacheStorage): + default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH" + + def __init__( + self, + file_path: str, + file_size: int, + numjobs: int, + bytes_per_page: int, + entries: int, + dtype: torch.dtype, + ): + self.file_path = file_path + self.file_size = file_size + self.numjobs = numjobs + self.bytes_per_page = bytes_per_page + self.entries = entries + self.dtype = dtype + + self.numel = self.bytes_per_page // self.dtype.itemsize + + self.num_pages = self.file_size // self.bytes_per_page + + logger.info( + "HiCacheHF3FS " + f"file_path = {self.file_path}, " + f"file_size = {self.file_size/(2**30):.2f} GB, " + f"numjobs = {self.numjobs}, " + f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, " + f"entries = {self.entries}, " + f"num_pages = {self.num_pages}" + ) + + self.ac = AtomicCounter(self.numjobs) + self.clients = [ + Hf3fsClient( + self.file_path, self.file_size, self.bytes_per_page, self.entries + ) + for _ in range(numjobs) + ] + self.executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS" + ) + + # Implemented a preliminary single-file page_hash -> file_offset index as interim storage. + # Future iterations may adopt a global KVCache manager to coordinate external cache instances + # through centralized metadata orchestration. 
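+        # The in-memory index is an OrderedDict used as an LRU: keys are moved to the
+        # back on every get/set hit, and once free_pages is exhausted the least-recently-used
+        # key is evicted so its page offset can be reused for the incoming write.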
+ self.lock = threading.RLock() + self.free_pages = list(range(self.num_pages)) + self.key_to_index = OrderedDict() + + atexit.register(self.close) + + signal.signal(signal.SIGINT, lambda sig, frame: self.close()) + signal.signal(signal.SIGTERM, lambda sig, frame: self.close()) + signal.signal(signal.SIGQUIT, lambda sig, frame: self.close()) + + @staticmethod + def from_env_config( + rank: int, bytes_per_page: int, dtype: torch.dtype + ) -> "HiCacheHF3FS": + config_path = os.getenv(HiCacheHF3FS.default_env_var) + if not config_path: + return HiCacheHF3FS( + file_path=f"/data/hicache.{rank}.bin", + file_size=1 << 40, + numjobs=16, + bytes_per_page=bytes_per_page, + entries=8, + dtype=dtype, + ) + + try: + with open(config_path, "r") as f: + config = json.load(f) + except Exception as e: + raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}") + + required_keys = { + "file_path_prefix", + "file_size", + "numjobs", + "entries", + } + missing_keys = required_keys - set(config.keys()) + if missing_keys: + raise ValueError(f"Missing required keys in config: {missing_keys}") + + return HiCacheHF3FS( + file_path=f"{config['file_path_prefix']}.{rank}.bin", + file_size=int(config["file_size"]), + numjobs=int(config["numjobs"]), + bytes_per_page=bytes_per_page, + entries=int(config["entries"]), + dtype=dtype, + ) + + def get( + self, key: str, target_location: Optional[torch.Tensor] = None + ) -> torch.Tensor | None: + return self.batch_get([key], target_location)[0] + + @synchronized() + def batch_get( + self, + keys: List[str], + target_locations: Optional[List[torch.Tensor]] = None, + ) -> List[torch.Tensor | None]: + batch_indices, file_offsets = [], [] + for i, key in enumerate(keys): + if key not in self.key_to_index: + continue + batch_indices.append(i) + file_offsets.append(self.key_to_index[key] * self.bytes_per_page) + self.key_to_index.move_to_end(key) + # TODO: target_locations + file_results = [ + torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices)) + ] + + futures = [ + self.executor.submit( + self.clients[self.ac.next()].batch_read, + file_offsets[i : i + self.entries], + file_results[i : i + self.entries], + ) + for i in range(0, len(batch_indices), self.entries) + ] + read_results = [result for future in futures for result in future.result()] + + results = [None] * len(keys) + for batch_index, file_result, read_result in zip( + batch_indices, file_results, read_results + ): + if read_result == self.bytes_per_page: + results[batch_index] = file_result + else: + logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed") + + return results + + def set(self, key: str, value: torch.Tensor) -> bool: + return self.batch_set([key], [value]) + + def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + indices = self.get_batch_set_indices(keys) + batch_indices, file_offsets, file_values = [], [], [] + for i, (value, (is_written, index)) in enumerate(zip(values, indices)): + if is_written or index == -1: + continue + batch_indices.append(i) + file_offsets.append(index * self.bytes_per_page) + file_values.append(value.contiguous()) + + futures = [ + self.executor.submit( + self.clients[self.ac.next()].batch_write, + file_offsets[i : i + self.entries], + file_values[i : i + self.entries], + ) + for i in range(0, len(batch_indices), self.entries) + ] + write_results = [ + result == self.bytes_per_page + for future in futures + for result in future.result() + ] + + results = [index[0] for index in indices] + for batch_index, 
write_result in zip(batch_indices, write_results): + key = keys[batch_index] + index = indices[batch_index][1] + if write_result: + self.key_to_index[key] = index + self.key_to_index.move_to_end(key) + else: + logger.error(f"HiCacheHF3FS set {key} failed") + self.free_pages.append(index) + results[batch_index] = write_result + return all(results) + + @synchronized() + def get_batch_set_indices(self, keys: List[str]) -> list: + ionum = len(keys) + # results: tuples of (is_written: bool, page_idx: int) + # - is_written: True = hit (no I/O), False = write (miss) + # - page_idx: page storing data + results = [None] * min(ionum, self.num_pages) + if ionum > self.num_pages: + results.extend([(False, -1)] * (ionum - self.num_pages)) + + new_keys = [] + for batch_index, key in enumerate(keys[: self.num_pages]): + if key in self.key_to_index: + results[batch_index] = (True, self.key_to_index[key]) + self.key_to_index.move_to_end(key) + else: + new_keys.append((batch_index, key)) + + for batch_index, _ in new_keys: + index = ( + self.free_pages.pop() + if len(self.free_pages) > 0 + else self.key_to_index.popitem(last=False)[1] + ) + results[batch_index] = (False, index) + + return results + + @synchronized() + def delete(self, key: str) -> None: + if key not in self.key_to_index: + return + index = self.key_to_index.pop(key) + self.free_pages.append(index) + + @synchronized() + def exists(self, key: str) -> bool: + return key in self.key_to_index + + @synchronized() + def clear(self) -> None: + self.free_pages = list(range(self.num_pages)) + self.key_to_index.clear() + + def close(self) -> None: + try: + for c in self.clients: + c.close() + self.executor.shutdown(wait=True) + except Exception as e: + logger.error(f"close HiCacheHF3FS: {e}") + logger.info("close HiCacheHF3FS") diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py b/python/sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py new file mode 100644 index 000000000..365effdef --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py @@ -0,0 +1,43 @@ +import multiprocessing.shared_memory +from pathlib import Path + +import pytest +import torch +from torch.utils.cpp_extension import load +from tqdm import tqdm + +root = Path(__file__).parent.resolve() +hf3fs_utils = load( + name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"], verbose=True +) + + +def test_rw_shm(): + numel = 8 << 20 + dtype = torch.bfloat16 + page_num = 128 + page_bytes = numel * dtype.itemsize + shm = multiprocessing.shared_memory.SharedMemory( + size=page_num * page_bytes, create=True + ) + tshm = torch.frombuffer(shm.buf, dtype=torch.uint8) + a = [ + torch.randn(numel, dtype=dtype) + for _ in tqdm(range(page_num), desc="prepare input") + ] + b = [ + torch.empty(numel, dtype=dtype) + for _ in tqdm(range(page_num), desc="prepare output") + ] + hf3fs_utils.write_shm(a, tshm) + hf3fs_utils.read_shm(tshm, b) + for _a, _b in tqdm(zip(a, b), desc="assert_close"): + torch.testing.assert_close(_a, _b) + + del tshm + shm.close() + shm.unlink() + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 856d68138..d53558211 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1476,7 +1476,7 @@ class ServerArgs: parser.add_argument( "--hicache-storage-backend", type=str, - choices=["file"], # todo, mooncake + choices=["file", "hf3fs"], # todo, mooncake default=ServerArgs.hicache_storage_backend, help="The storage 
backend for hierarchical KV cache.", )
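For reference, `HiCacheHF3FS.from_env_config` reads the JSON file pointed to by `SGLANG_HICACHE_HF3FS_CONFIG_PATH` and requires exactly the four keys validated in `storage_hf3fs.py` (`file_path_prefix`, `file_size`, `numjobs`, `entries`); `bytes_per_page` and `dtype` are supplied by the cache controller at runtime. A minimal sketch of generating such a config, mirroring what `bench_storage.py` does — the path and sizes are illustrative values, not requirements:

```python
import json

# Illustrative values only; file_path_prefix should sit on a 3FS mount.
config = {
    "file_path_prefix": "/data/hicache",  # per-rank files become /data/hicache.{rank}.bin
    "file_size": 1 << 40,                 # bytes reserved per rank (1 TiB here)
    "numjobs": 16,                        # number of Hf3fsClient workers per rank
    "entries": 8,                         # io-ring depth per client, also the batch chunk size
}
with open("/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json", "w") as f:
    json.dump(config, f)
```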
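The `bytes_per_page` handed to the backend is `mem_pool_host.get_size_per_token() * mem_pool_host.page_size`, and the benchmarks above assume it equals the flat-page element count `2 * tokens_per_page * layer_num * head_num * head_dim`. A quick sanity check for the Qwen3-32B-style settings used in `bench_storage.py` (bf16, 64-token pages, 8 KV heads) — a sketch assuming the standard K-plus-V-per-layer host pool layout:

```python
import torch

layer_num, head_num, head_dim = 64, 8, 128   # Qwen3-32B settings from bench_storage.py
tokens_per_page = 64                          # matches --page-size 64 in bench.sh
dtype = torch.bfloat16

numel = 2 * tokens_per_page * layer_num * head_num * head_dim  # K and V for every layer
bytes_per_page = numel * dtype.itemsize
assert bytes_per_page == 16 << 20  # 16 MiB, the page size used throughout the benchmarks
```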