diff --git a/benchmark/hf3fs/bench.sh b/benchmark/hf3fs/bench.sh
index bb1bbcd32..049116b89 100644
--- a/benchmark/hf3fs/bench.sh
+++ b/benchmark/hf3fs/bench.sh
@@ -1,6 +1,16 @@
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
+python3 benchmark/hf3fs/bench_client.py
+
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
 SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
 python3 benchmark/hf3fs/bench_storage.py
 
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
+export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json
+echo '{"file_path_prefix": "/data/hf3fs-test-0", "file_size": 1099511627776, "numjobs": 16, "entries": 8}' > \
+${SGLANG_HICACHE_HF3FS_CONFIG_PATH}
+python3 benchmark/hf3fs/bench_zerocopy.py
+
 ####################################################################################################
 rm -rf nohup.out && \
diff --git a/benchmark/hf3fs/bench_storage.py b/benchmark/hf3fs/bench_storage.py
index 4e96c8ec9..30702b635 100644
--- a/benchmark/hf3fs/bench_storage.py
+++ b/benchmark/hf3fs/bench_storage.py
@@ -8,6 +8,9 @@ from typing import List
 import torch
 from tqdm import tqdm
 
+from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+    Hf3fsLocalMetadataClient,
+)
 from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
 
 
@@ -67,12 +70,15 @@ def test():
         k = f"key_{i}"
         v = torch.randn((numel,)).to(dtype=dtype)
         ok = hicache_hf3fs.set(k, v)
-        assert ok, f"Failed to insert {k}"
+        if i < (file_size // bytes_per_page):
+            assert ok, f"Failed to insert {k}"
+        else:
+            assert not ok
        tensors[k] = v
 
-    assert hicache_hf3fs.get("key_0") is None
-    assert hicache_hf3fs.get("key_1") is None
+    assert hicache_hf3fs.get("key_8") is None
+    assert hicache_hf3fs.get("key_9") is None
 
-    start = num_pages - hicache_hf3fs.num_pages
+    start = 0
     for i in range(start, start + hicache_hf3fs.num_pages):
         k = f"key_{i}"
         assert hicache_hf3fs.exists(k)
@@ -83,13 +89,16 @@ def test():
 
     assert not hicache_hf3fs.exists("not_exists")
 
-    hicache_hf3fs.delete("key_9")
+    hicache_hf3fs.delete("key_7")
     v2 = torch.randn((numel,)).to(dtype=dtype)
     assert hicache_hf3fs.set("key_new", v2)
     assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
 
     hicache_hf3fs.clear()
-    assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages
+    assert (
+        len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
+        == hicache_hf3fs.metadata_client.rank_metadata.num_pages
+    )
 
     # batch
     num_pages = 10
@@ -134,12 +143,14 @@ def bench():
     entries = 8
     dtype = store_dtype
     hicache_hf3fs = HiCacheHF3FS(
+        rank=0,
         file_path=file_path,
         file_size=file_size,
         numjobs=numjobs,
         bytes_per_page=bytes_per_page,
         entries=entries,
         dtype=dtype,
+        metadata_client=Hf3fsLocalMetadataClient(),
     )
 
     numel = 2 * tokens_per_page * layer_num * head_num * head_dim
@@ -167,7 +178,10 @@ def bench():
     r_bw = []
     r_size = num_page * bytes_per_page / (1 << 30)
     for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
-        keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
+        keys = random.sample(
+            list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
+            num_page,
+        )
         tik = time.perf_counter()
         results = hicache_hf3fs.batch_get(keys)
         tok = time.perf_counter()
@@ -195,12 +209,14 @@ def allclose():
     entries = 8
     dtype = store_dtype
     hicache_hf3fs = HiCacheHF3FS(
+        rank=0,
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
+        metadata_client=Hf3fsLocalMetadataClient(),
     )
 
     numel = 2 * tokens_per_page * layer_num * head_num * head_dim
@@ -218,7 +234,10 @@ def allclose():
 
     read_keys, read_results = [], []
     for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
-        keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
+        keys = random.sample(
+            list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
+            num_page,
+        )
         results = hicache_hf3fs.batch_get(keys)
         read_keys.extend(keys)
         read_results.extend(results)
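A minimal sketch of driving HiCacheHF3FS with the new local metadata client, in the spirit of bench_storage.py above; the path, capacity, and dimensions below are illustrative placeholders, not the benchmark's settings.

```python
import torch

from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
    Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

dtype = torch.bfloat16
numel = 2 * 64 * 64 * 8 * 128            # 2 * tokens_per_page * layer_num * head_num * head_dim (illustrative)
bytes_per_page = numel * dtype.itemsize  # one KV page

hicache = HiCacheHF3FS(
    rank=0,
    file_path="/data/hf3fs-test-0.bin",   # placeholder path
    file_size=16 * bytes_per_page,        # capacity for 16 pages
    numjobs=16,
    entries=8,
    bytes_per_page=bytes_per_page,
    dtype=dtype,
    metadata_client=Hf3fsLocalMetadataClient(),  # in-process metadata, as in the benchmarks
)

v = torch.randn((numel,)).to(dtype=dtype)
assert hicache.set("key_0", v)                             # returns False once all pages are taken
assert torch.allclose(hicache.get("key_0"), v, atol=1e-3)  # round-trip through hf3fs
hicache.clear()
```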
diff --git a/benchmark/hf3fs/bench_zerocopy.py b/benchmark/hf3fs/bench_zerocopy.py
new file mode 100644
index 000000000..bfa7bff0e
--- /dev/null
+++ b/benchmark/hf3fs/bench_zerocopy.py
@@ -0,0 +1,140 @@
+import threading
+import time
+
+import torch
+from tqdm import tqdm
+
+from sglang.srt.distributed import (
+    get_world_group,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
+from sglang.srt.managers.cache_controller import (
+    HiCacheController,
+    PrefetchOperation,
+    StorageOperation,
+)
+from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
+from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
+
+init_distributed_environment(
+    world_size=1,
+    rank=0,
+    distributed_init_method="tcp://127.0.0.1:23456",
+    local_rank=0,
+    backend="gloo",
+)
+
+initialize_model_parallel(
+    tensor_model_parallel_size=1,
+    pipeline_model_parallel_size=1,
+)
+
+group = get_world_group().cpu_group
+
+max_total_num_tokens = 524288
+page_size = 64
+kv_cache_dtype = torch.bfloat16
+layer_num = 64
+head_num, head_dim = 8, 128
+device = "cuda"
+hicache_ratio = 2
+hicache_size = 0
+hicache_mem_layout = "page_first"
+# hicache_mem_layout = "layer_first"
+hicache_write_policy = "write_through"
+hicache_io_backend = "kernel"
+hicache_storage_backend = "hf3fs"
+prefetch_threshold = 256
+
+op_size = 1024
+op_num = 16
+
+token_to_kv_pool = MHATokenToKVPool(
+    max_total_num_tokens,
+    page_size=page_size,
+    dtype=kv_cache_dtype,
+    head_num=head_num,
+    head_dim=head_dim,
+    layer_num=layer_num,
+    device=device,
+    enable_memory_saver=True,
+)
+
+token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+    max_total_num_tokens,
+    dtype=kv_cache_dtype,
+    device=device,
+    kvcache=token_to_kv_pool,
+    need_sort=False,
+)
+
+kv_cache = token_to_kv_pool_allocator.get_kvcache()
+token_to_kv_pool_host = MHATokenToKVPoolHost(
+    kv_cache,
+    hicache_ratio,
+    hicache_size,
+    page_size,
+    hicache_mem_layout,
+)
+
+load_cache_event = threading.Event()
+cache_controller = HiCacheController(
+    token_to_kv_pool_allocator,
+    token_to_kv_pool_host,
+    page_size,
+    group,
+    load_cache_event=load_cache_event,
+    write_policy=hicache_write_policy,
+    io_backend=hicache_io_backend,
+    storage_backend=hicache_storage_backend,
+    prefetch_threshold=prefetch_threshold,
+)
+
+operations = [
+    StorageOperation(
+        torch.tensor(list(range(i, i + op_size))),
+        list(range(i, i + op_size)),
+        hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
+    )
+    for i in tqdm(range(0, op_num * op_size, op_size))
+]
+
+tik = time.monotonic()
+if hicache_mem_layout == "page_first":
+    for operation in operations:
+        cache_controller.zerocopy_page_backup(operation, batch_size=128)
+elif hicache_mem_layout == "layer_first":
+    for operation in operations:
+        cache_controller.generic_page_backup(operation, batch_size=128)
+tok = time.monotonic()
+print(f"{tok-tik:.6f} s")
+
+operations = [
+    PrefetchOperation(
+        f"{i}",
+        torch.tensor(list(range(i, i + op_size))),
+        list(range(i, i + op_size)),
+        f"{i}",
+    )
+    for i in tqdm(range(0, op_num * op_size, op_size))
+]
+
+for operation in operations:
+    operation.hash_value = [
+        f"{j}"
+        for j in range(
+            int(operation.last_hash), int(operation.last_hash) + op_size, page_size
+        )
+    ]
+
+tik = time.monotonic()
+if hicache_mem_layout == "page_first":
+    for operation in operations:
+        cache_controller.zerocopy_page_transfer(operation, batch_size=128)
+elif hicache_mem_layout == "layer_first":
+    for operation in operations:
+        cache_controller.generic_page_transfer(operation, batch_size=128)
+tok = time.monotonic()
+print(f"{tok-tik:.6f} s")
diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index b25bf4032..6ba9571b5 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -268,9 +268,14 @@ class HiCacheController:
             )
             rank = get_tensor_model_parallel_rank()
-            bytes_per_page = (
-                mem_pool_host.get_size_per_token() * mem_pool_host.page_size
-            )
+            if self.mem_pool_host.layout == "page_first":
+                bytes_per_page = (
+                    mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
+                )
+            elif self.mem_pool_host.layout == "layer_first":
+                bytes_per_page = (
+                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+                )
             dtype = mem_pool_host.dtype
             self.storage_backend = HiCacheHF3FS.from_env_config(
                 rank, bytes_per_page, dtype
             )
@@ -555,13 +560,34 @@
             operation.mark_done()
             return operation.completed_tokens, operation.hash_value
 
+    def zerocopy_page_transfer(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_dsts = dsts[i : i + batch_size]
+            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
+            if page_data is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                )
+                break
+            completed_tokens = operation.completed_tokens
+            if operation.increment(self.page_size * len(page_hashes)):
+                for i in range(len(page_hashes)):
+                    completed_tokens += self.page_size
+            else:
+                break
+
     def generic_page_transfer(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             # todo: zero copy
-            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-                page_hashes
-            )
+            dummy_page_dst = [
+                self.mem_pool_host.get_dummy_flat_data_page()
+                for _ in range(len(page_hashes))
+            ]
             page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
@@ -599,7 +625,10 @@
             if self.is_mooncake_backend():
                 self.mooncake_page_transfer(operation)
             elif self.storage_backend_type == "hf3fs":
-                self.generic_page_transfer(operation, batch_size=128)
+                if self.mem_pool_host.layout == "page_first":
+                    self.zerocopy_page_transfer(operation, batch_size=128)
+                elif self.mem_pool_host.layout == "layer_first":
+                    self.generic_page_transfer(operation, batch_size=128)
             else:
                 self.generic_page_transfer(operation)
 
@@ -716,6 +745,19 @@
         self.backup_queue.put(operation)
         return operation.id
 
+    def zerocopy_page_backup(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_data = dsts[i : i + batch_size]
+            success = self.storage_backend.batch_set(page_hashes, page_data)
+            if not success:
+                logger.warning(f"Failed to write page {page_hashes} to storage.")
+                break
+            operation.completed_tokens += self.page_size * len(page_hashes)
+
     def generic_page_backup(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
@@ -770,7 +812,10 @@
             if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
             elif self.storage_backend_type == "hf3fs":
-                self.generic_page_backup(operation, batch_size=128)
+                if self.mem_pool_host.layout == "page_first":
+                    self.zerocopy_page_backup(operation, batch_size=128)
+                elif self.mem_pool_host.layout == "layer_first":
+                    self.generic_page_backup(operation, batch_size=128)
             else:
                 self.generic_page_backup(operation)
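The layout-dependent bytes_per_page above follows the MHA host-pool size helpers (get_size_per_token(), plus the new get_ksize_per_token() added in memory_pool_host.py below): with the page_first layout the K and V halves of a page are written as separate objects, so each stored page is half the size of a full token page. A worked example with illustrative dimensions, not any particular model's:

```python
# Illustrative numbers only; mirrors get_size_per_token() / get_ksize_per_token() for MHA.
head_num, head_dim, layer_num, page_size = 8, 128, 64, 64
itemsize = 2  # bytes per bfloat16 element

size_per_token = head_dim * head_num * layer_num * itemsize * 2  # K and V together
ksize_per_token = size_per_token // 2                            # K only (V is the same size)

bytes_per_page_layer_first = size_per_token * page_size   # 16 MiB: one object carries K and V
bytes_per_page_page_first = ksize_per_token * page_size   # 8 MiB: K page and V page stored separately
print(bytes_per_page_layer_first, bytes_per_page_page_first)  # 16777216 8388608
```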
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index cfc7f36c5..4abc6dc0a 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -307,6 +307,9 @@ class MHATokenToKVPoolHost(HostKVCache):
 
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token() // 2
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
@@ -496,6 +499,21 @@ class MHATokenToKVPoolHost(HostKVCache):
            element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        key_list = []
+        buf_list = []
+
+        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+            key_list.append(f"{key}-k")
+            buf_list.append(self.k_buffer[i : i + self.page_size])
+            key_list.append(f"{key}-v")
+            buf_list.append(self.v_buffer[i : i + self.page_size])
+
+        return key_list, buf_list
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
@@ -538,6 +556,9 @@ class MLATokenToKVPoolHost(HostKVCache):
             * self.layer_num
         )
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token()
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (
@@ -704,3 +725,14 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
+
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        buf_list = []
+
+        for i in range(0, len(indices), self.page_size):
+            buf_list.append(self.kv_buffer[i : i + self.page_size])
+
+        return keys, buf_list
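A self-contained sketch of the key/buffer pairing that the new get_buffer_with_hash performs for the page_first MHA layout; the tensors below are small stand-ins for the host k_buffer/v_buffer, not the real pool, and the shapes are arbitrary.

```python
import torch

page_size = 4
k_buffer = torch.zeros(2 * page_size, 3)  # stand-in for the host K buffer
v_buffer = torch.zeros(2 * page_size, 3)  # stand-in for the host V buffer
keys = ["abc", "def"]
indices = list(range(2 * page_size))      # host indices covering two pages

key_list, buf_list = [], []
for key, i in zip(keys, range(0, len(indices), page_size)):
    key_list.append(f"{key}-k")
    buf_list.append(k_buffer[i : i + page_size])
    key_list.append(f"{key}-v")
    buf_list.append(v_buffer[i : i + page_size])

assert key_list == ["abc-k", "abc-v", "def-k", "def-v"]
# Each entry of buf_list is a contiguous view into the pool, so handing these
# tensors to batch_set / batch_get reads and writes the host memory in place.
```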
diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md b/python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md
index 5fa1fa4c2..7c7c0bfb2 100644
--- a/python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md
@@ -34,6 +34,9 @@ apt-get update \
     python3 python3-pip \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
+# apt install python3.12 python3.12-venv python3.12-dev
+# curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+# python3.12 get-pip.py
 
 # Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
 python3 setup.py bdist_wheel
@@ -60,6 +63,6 @@ apt update && apt install -y \
     libuv1-dev
 
 # Install Python Package
-pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
+pip install hf3fs_py_usrbio-1.2.9+394583d-cp312-cp312-linux_x86_64.whl
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages
 ```
diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
index e7dd01c73..b301ee0c8 100644
--- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -7,7 +7,7 @@ import signal
 import threading
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 
@@ -228,15 +228,23 @@ class HiCacheHF3FS(HiCacheStorage):
         )
 
     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
-        return self.batch_get([key], [target_location] if target_location else None)[0]
+        return self.batch_get(
+            [key],
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )[0]
 
     @synchronized()
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         page_indices = self.metadata_client.get_page_indices(self.rank, keys)
 
@@ -246,9 +254,15 @@ class HiCacheHF3FS(HiCacheStorage):
                 batch_indices.append(i)
                 file_offsets.append(page_index * self.bytes_per_page)
 
-        file_results = [
-            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
-        ]
+        if target_locations is not None:
+            for target_location in target_locations:
+                assert target_location.is_contiguous()
+            file_results = target_locations
+        else:
+            file_results = [
+                torch.empty(self.numel, dtype=self.dtype)
+                for _ in range(len(batch_indices))
+            ]
 
         futures = [
             self.executor.submit(
@@ -273,10 +287,27 @@
 
         return results
 
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        return self.batch_set(
+            [key],
+            [value] if value is not None else None,
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -292,7 +323,8 @@
             batch_indices.append(i)
             file_offsets.append(page_index * self.bytes_per_page)
-            file_values.append(value.contiguous())
+            assert value.is_contiguous()
+            file_values.append(value)
 
         futures = [
             self.executor.submit(