From 86cb4db058ff368d50161101f97e2a5381b69d48 Mon Sep 17 00:00:00 2001 From: Shisong Ma <31835442+mss1213@users.noreply.github.com> Date: Wed, 1 Oct 2025 21:43:34 +0800 Subject: [PATCH] [Feature] Add EIC as sglang HiCache Storage backend (#10271) Co-authored-by: mashisong --- .../sglang/srt/managers/cache_controller.py | 2 +- .../srt/mem_cache/storage/backend_factory.py | 8 + .../srt/mem_cache/storage/eic/README.md | 24 + .../srt/mem_cache/storage/eic/eic_storage.py | 778 ++++++++++++++++++ .../srt/mem_cache/storage/eic/test_unit.py | 115 +++ python/sglang/srt/server_args.py | 2 +- 6 files changed, 927 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/mem_cache/storage/eic/README.md create mode 100644 python/sglang/srt/mem_cache/storage/eic/eic_storage.py create mode 100644 python/sglang/srt/mem_cache/storage/eic/test_unit.py diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index c2c14707e..c7d0218e5 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -310,7 +310,7 @@ class HiCacheController: self.page_get_func = self._generic_page_get self.page_set_func = self._generic_page_set - if self.storage_backend_type in ["hf3fs", "mooncake"]: + if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]: self.page_get_func = self._page_get_zero_copy self.page_set_func = self._page_set_zero_copy diff --git a/python/sglang/srt/mem_cache/storage/backend_factory.py b/python/sglang/srt/mem_cache/storage/backend_factory.py index a141afb21..3af598abb 100644 --- a/python/sglang/srt/mem_cache/storage/backend_factory.py +++ b/python/sglang/srt/mem_cache/storage/backend_factory.py @@ -181,6 +181,8 @@ class StorageBackendFactory: dtype = mem_pool_host.dtype return backend_class.from_env_config(bytes_per_page, dtype, storage_config) + elif backend_name == "eic": + return backend_class(storage_config, mem_pool_host) else: raise 
ValueError(f"Unknown built-in backend: {backend_name}") @@ -213,3 +215,9 @@ StorageBackendFactory.register_backend( "sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage", "AibrixKVCacheStorage", ) + +StorageBackendFactory.register_backend( + "eic", + "sglang.srt.mem_cache.storage.eic.eic_storage", + "EICStorage", +) diff --git a/python/sglang/srt/mem_cache/storage/eic/README.md b/python/sglang/srt/mem_cache/storage/eic/README.md new file mode 100644 index 000000000..1f91d0378 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/eic/README.md @@ -0,0 +1,24 @@ +# EIC as sglang HiCache Storage +EIC (Elastic Instant Cache) is a distributed database designed for LLM KV Cache. It supports RDMA and GDR, and provides distributed disaster-tolerance and scaling capabilities. +You can understand the principles and architecture of EIC through these articles: https://mp.weixin.qq.com/s/tasDqXf0Gxr3o_WCJ2IJUQ https://mp.weixin.qq.com/s/b_4YhTa96Zeklh23lv8qBw + + +## Deploy EIC +You can visit the official link https://console.volcengine.com/eic and deploy EIC KVCache on your compute cluster with the web UI. In addition, we provide a dedicated image on Volcano Engine, which integrates various optimizations based on the official image. +You may use test_unit.py to verify connectivity to EIC. + + + +## Deploy Model With EIC +You can enable EIC KVCache offload with the official interface, such as + +```bash +python -m sglang.launch_server \ + --model-path [model_path] \ + --enable-hierarchical-cache \ + --hicache-storage-backend eic \ + --hicache-write-policy 'write_through' \ + --hicache-mem-layout 'page_first' + +``` +For more details, you can see https://www.volcengine.com/docs/85848/1749188.
diff --git a/python/sglang/srt/mem_cache/storage/eic/eic_storage.py b/python/sglang/srt/mem_cache/storage/eic/eic_storage.py new file mode 100644 index 000000000..4c4ea89eb --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/eic/eic_storage.py @@ -0,0 +1,778 @@ +import json +import logging +import os +import time +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import eic +import torch +import yaml + +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.mem_cache.hicache_storage import ( + HiCacheStorage, + HiCacheStorageConfig, + HiCacheStorageExtraInfo, +) +from sglang.srt.mem_cache.memory_pool_host import HostKVCache, MLATokenToKVPoolHost + +logger = logging.getLogger(__name__) + + +TensorPoolSize = 2048 + +REMOTE_EIC_YAML_ENV_VAR = "REMOTE_EIC_YAML" + +# gpu direct rdma for kv set +G_EnableKVSetGPUDirect = False + +# gpu direct rdma for kv get +G_EnableKVGetGPUDirect = False + +# gpu nic affinity +G_EnableGPUNicAffinity = False + +# default H20 gpu nic affinity +GPUNicAffinity = { + "cuda:0": "eth1", + "cuda:1": "eth1", + "cuda:2": "eth2", + "cuda:3": "eth2", + "cuda:4": "eth3", + "cuda:5": "eth3", + "cuda:6": "eth4", + "cuda:7": "eth4", +} + +# default H20 cpu nic affinity +CPUNicAffinity = { + "cuda:0": "cpu", + "cuda:1": "cpu", + "cuda:2": "cpu", + "cuda:3": "cpu", + "cuda:4": "cpu", + "cuda:5": "cpu", + "cuda:6": "cpu", + "cuda:7": "cpu", +} + + +def get_eic_config_file_path(): + if os.environ.get(REMOTE_EIC_YAML_ENV_VAR) is not None: + logger.info(f"eic init with env var {REMOTE_EIC_YAML_ENV_VAR}") + config_file = os.environ.get(REMOTE_EIC_YAML_ENV_VAR) + else: + config_file = "/sgl-workspace/config/remote-eic.yaml" + logger.info(f"eic init with default config, config_file {config_file}") + return config_file + + +class FlexibleKVCacheMemoryPool: + def __init__(self, conn, kvcache_shape, kvcache_dtype, device): + self.connection = conn + + if 
device.startswith("cpu") and G_EnableGPUNicAffinity: + gpu_id = torch.cuda.current_device() + self.device = CPUNicAffinity["cuda:" + str(gpu_id)] + # current memory pool size is 5 times of CPU TensorPoolSize + mempool_size = TensorPoolSize * 5 + else: + self.device = device + mempool_size = TensorPoolSize + + self.kvcache_shape = kvcache_shape + self.kvcache_dtype = kvcache_dtype + + self.kv_cache_numel = 1 + for i in self.kvcache_shape: + self.kv_cache_numel *= i + + self.free_data_addr = set() + self.data_ptr_to_index = dict() + + if self.device.startswith("cpu"): + self.kvcache_mempool = torch.zeros( + (mempool_size,) + kvcache_shape, + dtype=kvcache_dtype, + device=self.device, + pin_memory=True, + ) + else: + self.kvcache_mempool = torch.zeros( + (mempool_size,) + kvcache_shape, dtype=kvcache_dtype, device=self.device + ) + + for i in range(mempool_size): + self.free_data_addr.add(i) + self.data_ptr_to_index[self.kvcache_mempool[i].data_ptr()] = i + + meminfo = eic.MemoryInfo() + meminfo.type = eic.MemoryType.MEMORY_CUDA + meminfo.cuda_id = 0 + vals = eic.IOBuffers() + vals.append( + self.kvcache_mempool.data_ptr(), + self.kvcache_mempool.numel() * self.kvcache_mempool.element_size(), + True, + ) + self.connection.register_memory(vals, meminfo) + logger.info( + f"allocate memory pool, size {self.kvcache_mempool.numel() * self.kvcache_mempool.element_size()}, device {self.device}" + ) + + def try_allocate_kv_cache(self, shape, dtype, count=1): + if len(self.free_data_addr) < count: + return None + + numel = 1 + for i in shape: + numel *= i + if numel != self.kv_cache_numel or dtype != self.kvcache_dtype: + logger.error( + f"allocate from mempool failed, self.kvcache_shape {self.kvcache_shape}, dtype {self.kvcache_dtype}, require shape {shape}, dtype {dtype}" + ) + return None + + ret = [] + for _ in range(count): + free_index = self.free_data_addr.pop() + ret.append(self.kvcache_mempool[free_index]) + return ret + + def free_to_mempool(self, data_ptr): + if 
data_ptr not in self.data_ptr_to_index: + logger.error( + f"free_to_mempool failed, data_ptr {data_ptr} not in allocated_data_addr" + ) + return + self.free_data_addr.add(self.data_ptr_to_index[data_ptr]) + + def check_data_ptr_allocated(self, data_ptr): + return data_ptr in self.data_ptr_to_index + + def left_count(self): + return len(self.free_data_addr) + + +class EICStorage(HiCacheStorage): + def __init__( + self, hicache_config: HiCacheStorageConfig, memory_pool_host: HostKVCache + ): + global G_EnableKVSetGPUDirect, G_EnableKVGetGPUDirect + global GPUNicAffinity, CPUNicAffinity, G_EnableGPUNicAffinity + + config_file = get_eic_config_file_path() + if os.path.exists(config_file) is False: + logger.error(f"config file {config_file} not exists") + raise RuntimeError(f"eic config file {config_file} not exists") + + with open(config_file, "r") as fin: + config = yaml.safe_load(fin) + + remote_url = config.get("remote_url", None) + if remote_url is None: + AssertionError("remote_url is None") + + endpoint = remote_url[len("eic://") :] + + logger.info(f"eic remote_url:" + remote_url + " endpoint: " + endpoint) + + eic_instance_id = config.get("eic_instance_id", None) + logger.info(f"eic instance_id: {eic_instance_id}") + + eic_thread_num = config.get("eic_thread_num", 1) + logger.info(f"eic thread_num: {eic_thread_num}") + + eic_log_dir = config.get("eic_log_dir", None) + logger.info(f"eic log_dir: {eic_log_dir}") + + eic_log_level = config.get("eic_log_level", 2) + logger.info(f"eic log_level: {eic_log_level}") + + eic_trans_type = config.get("eic_trans_type", 3) + logger.info(f"eic trans_type: {eic_trans_type}") + + eic_flag_file = config.get("eic_flag_file", None) + logger.info(f"eic flag_file: {eic_flag_file}") + + # GDR now is not used + G_EnableKVSetGPUDirect = ( + config.get("enable_kvset_gpu_direct", False) and torch.cuda.is_available() + ) + logger.debug(f"eic enable_kvset_gpu_direct: {G_EnableKVSetGPUDirect}") + + G_EnableKVGetGPUDirect = ( + 
config.get("enable_kvget_gpu_direct", False) and torch.cuda.is_available() + ) + logger.debug(f"eic enable_kvget_gpu_direct: {G_EnableKVGetGPUDirect}") + + self.model_name = hicache_config.model_name + + # rdma + enable_kv_set_direct = config.get("enable_kvset_direct", True) + logger.info(f"eic enable_kv_set_direct: {enable_kv_set_direct}") + self.enable_kv_set_direct = enable_kv_set_direct + + enable_kv_get_direct = config.get("enable_kvget_direct", True) + logger.info(f"eic enable_kv_get_direct: {enable_kv_get_direct}") + self.enable_kv_get_direct = enable_kv_get_direct + + # gpu nic affinity + G_EnableGPUNicAffinity = config.get("enable_gpu_nic_affinity", False) + logger.info(f"eic enable_gpu_nic_affinity: {G_EnableGPUNicAffinity}") + self.enable_gpu_nic_affinity = G_EnableGPUNicAffinity + + if G_EnableGPUNicAffinity: + if "gpu_nic_affinity_config" in config: + GPUNicAffinity = json.loads(config["gpu_nic_affinity_config"]) + if "cpu_nic_affinity_config" in config: + CPUNicAffinity = json.loads(config["cpu_nic_affinity_config"]) + logger.info(f"eic gpu nic affinity {GPUNicAffinity}") + logger.info(f"eic cpu nic affinity {CPUNicAffinity}") + + eic_namespace = config.get("eic_namespace", "") + logger.info(f"eic namespace: {eic_namespace}") + self.eic_namespace = eic_namespace + + if not os.path.exists(eic_log_dir) and not os.path.isdir(eic_log_dir): + os.makedirs(eic_log_dir, exist_ok=True) + + self.connection = eic.Client() + init_option = eic.InitOption() + init_option.log_dir = eic_log_dir + init_option.log_level = eic.LogLevel(eic_log_level) + init_option.transport_type = eic.TransportType(eic_trans_type) + init_option.flag_file = eic_flag_file + + if G_EnableGPUNicAffinity: + gpu_id = torch.cuda.current_device() + init_option.multi_net_local_interface_names = GPUNicAffinity[ + "cuda:" + str(gpu_id) + ] + logger.info( + f"gpu {gpu_id} set gpu nic affinity to {init_option.multi_net_local_interface_names}" + ) + + ret = self.connection.init(eic_instance_id, 
endpoint, init_option) + if ret != 0: + logger.error(f"fail to init eic client, ret: {ret}") + raise RuntimeError("EIC Client Init Failed.") + self.warmup() + + self.memory_pool_host = memory_pool_host + self.host_kvcache_layout = self.memory_pool_host.layout + self.trans_type = eic.TransportType(eic_trans_type) + self.kv_cache_dtype = self.memory_pool_host.dtype + self.is_mla_model = hicache_config.is_mla_model + self.rank = hicache_config.tp_rank + self.world_size = hicache_config.tp_size + self.page_size = self.memory_pool_host.page_size + self.use_zero_copy = self.memory_pool_host.layout == "page_first" + if not self.use_zero_copy: + self.kv_cache_shape = self.memory_pool_host.get_data_page( + 0, flat=True + ).shape + if self.enable_kv_set_direct: + self.kv_cache_write_mem_pool = FlexibleKVCacheMemoryPool( + self.connection, self.kv_cache_shape, self.kv_cache_dtype, "cpu" + ) + if self.enable_kv_get_direct: + self.kv_cache_get_mem_pool = FlexibleKVCacheMemoryPool( + self.connection, self.kv_cache_shape, self.kv_cache_dtype, "cpu" + ) + self._init_eic_prefix() + + def warmup(self): + logger.info("begin warm up eic client") + start_time = time.perf_counter() + num_warmup = 1024 + preheat_keys = ["warmup_key_" + str(i) for i in range(num_warmup)] + batch_size = 32 + for i in range(0, num_warmup, batch_size): + keys_vec = eic.StringVector() + for key in preheat_keys[i : i + batch_size]: + keys_vec.append(key) + exist_option = eic.ExistOption() + _, _ = self.connection.mexist(keys_vec, exist_option) + logger.info( + f"finish eic client warm up, warm up cost {time.perf_counter() - start_time:.2f} seconds" + ) + + def register_mem_pool_host(self, memory_pool_host: HostKVCache) -> None: + # no need judge meminfo type, cuda_id, etc. 
+ meminfo = eic.MemoryInfo() + meminfo.type = eic.MemoryType.MEMORY_CUDA + meminfo.cuda_id = 0 + vals = eic.IOBuffers() + buffer = memory_pool_host.kv_buffer + vals.append( + buffer.data_ptr(), + buffer.numel() * buffer.element_size(), + True, + ) + self.connection.register_memory(vals, meminfo) + + def _init_eic_prefix(self): + if self.is_mla_model: + self.eic_prefix = ( + f"{self.model_name}_mla_att_{self.host_kvcache_layout}@sglang" + ) + else: + self.eic_prefix = f"{self.model_name}_mha_attn_{self.host_kvcache_layout}_{self.rank}_{self.world_size}_@sglang" + + def _get_eic_key(self, keys: List[str]) -> str: + return [f"{self.eic_prefix}_{key}" for key in keys] + + def set( + self, + key: str, + value: Optional[Any] = None, + target_location: Optional[Any] = None, + target_size: Optional[Any] = None, + ) -> bool: + # now is not used + if self.use_zero_copy: + return self.zero_copy_batch_set([key], [target_location]) + else: + return self.generic_batch_set([key], [value]) + + # target_locations and target_sizes are not used for now + def batch_set( + self, + keys: List[str], + values: Optional[Any] = None, + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: + if len(keys) == 0: + return True + if self.use_zero_copy: + return self.zero_copy_batch_set(keys, values) + else: + return self.generic_batch_set(keys, values) + + def get( + self, + key, + target_location: Optional[Any] = None, + target_size: Optional[Any] = None, + ) -> torch.Tensor | None: + # now is not used + if self.use_zero_copy: + return self.zero_copy_batch_get([key], [target_location]) + else: + return self.generic_batch_get([key], [target_location]) + + # use for v1 interface, and shound not be called directly + def batch_get( + self, + keys: List[str], + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> List[torch.Tensor | None]: + assert len(keys) == len(target_locations) + if len(keys) == 0: + return None + if 
self.use_zero_copy: + return self.zero_copy_batch_get(keys, target_locations) + else: + return self.generic_batch_get(keys, target_locations) + + def _batch_exists_impl(self, keys) -> List[bool]: + if len(keys) == 0: + return 0 + eic_keys = self._get_eic_key(keys) + logger.debug(f"eic exists {len(keys)}") + result = [] + exist_bs = 1024 + for i in range(0, len(eic_keys), exist_bs): + batch_keys = eic_keys[i : i + exist_bs] + keys_vec = eic.StringVector() + for key in batch_keys: + keys_vec.append(key) + exist_option = eic.ExistOption() + exist_option.ns = self.eic_namespace + status_code, exist_outcome = self.connection.mexist(keys_vec, exist_option) + if status_code != eic.StatusCode.SUCCESS: + logger.error( + f"eic exists {len(keys)} failed, status_code {status_code}" + ) + result.extend([False] * len(batch_keys)) + for err_code in exist_outcome.status_codes: + result.append(err_code == eic.StatusCode.SUCCESS) + return result + + def exists(self, key) -> bool: + exist_num = self.batch_exists([key]) + return exist_num == 1 + + def batch_exists(self, keys) -> int: + if len(keys) == 0: + return 0 + if self.use_zero_copy and not self.is_mla_model: + keys = self._get_mha_zero_copy_keys(keys) + exist_mask = self._batch_exists_impl(keys) + prefix_success = 0 + for exist in exist_mask: + if exist: + prefix_success += 1 + else: + break + if not self.is_mla_model and self.use_zero_copy: + prefix_success = prefix_success // 2 + return prefix_success + + def delete(self, key) -> None: + eic_keys = self._get_eic_key([key]) + keys_vec = eic.StringVector() + for eic_key in eic_keys: + keys_vec.append(eic_key) + del_option = eic.DelOption() + self.connection.mdel(keys_vec, del_option) + + def clear(self) -> None: + return + + # Not used for now + def _filter_kv_cache(self, total_len) -> Tuple[int, int]: + mean_len = total_len // self.world_size + remainder = total_len % self.world_size + tp_keys_len = mean_len + (1 if self.rank < remainder else 0) + start = self.rank * mean_len 
+ min(self.rank, remainder) + end = start + tp_keys_len + logger.debug(f"start: {start}, end: {end}, tp_keys_len: {tp_keys_len}") + return start, end + + def zero_copy_batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + logger.debug(f"eic zero copy set {len(keys)} keys") + if len(keys) == 0: + return True + eic_keys = self._get_eic_key(keys) + keys_vec = eic.StringVector() + vals_vec = eic.IOBuffers() + # set data key & value + for i, key in enumerate(eic_keys): + # set data key & value + keys_vec.append(key) + vals_vec.append( + values[i].data_ptr(), + values[i].element_size() * values[i].numel(), + True, + ) + # set options + set_option = eic.SetOption() + set_option.ns = self.eic_namespace + set_option.ttl_second = -1 + status_code, set_outcome = self.connection.mset(keys_vec, vals_vec, set_option) + if status_code != eic.StatusCode.SUCCESS: + logger.error(f"eic mset {len(keys)} failed, status_code {status_code}") + return [False] * len(keys) + else: + logger.debug(f"eic zero copy mset {len(keys)} success") + return [True] * len(keys) + + def zero_copy_batch_get( + self, keys: List[str], values: List[torch.Tensor] + ) -> List[bool]: + logger.debug(f"eic zero copy get {len(keys)} keys") + # Get Data: generate data keys and vals + get_data_start_time = time.perf_counter() + eic_keys = self._get_eic_key(keys) + data_keys = eic.StringVector() + data_vals = eic.IOBuffers() + success_mask = [True] * len(keys) + count = len(keys) + for i, key in enumerate(eic_keys): + data_keys.append(key) + data_vals.append( + values[i].data_ptr(), + values[i].element_size() * values[i].numel(), + True, + ) + + # Get data: recv data buffer tensor + get_option = eic.GetOption() + get_option.ns = self.eic_namespace + status_code, data_vals, get_outcome = self.connection.mget( + data_keys, get_option, data_vals + ) + + if status_code != eic.StatusCode.SUCCESS: + if status_code == eic.StatusCode.PARTIAL_FAILED: + for i, err_code in enumerate(get_outcome.status_codes): 
+ success = err_code == eic.StatusCode.SUCCESS + if success: + logger.debug(f"eic get data {eic_keys[i]} success") + else: + logger.error( + f"eic get data {eic_keys[i]} failed, err_code {err_code}" + ) + success_mask[i] = False + else: + logger.error( + f"eic mget {len(eic_keys)} keys failed, status_code {status_code}" + ) + success_mask = [False] * len(keys) + return success_mask + + get_data_end_time = time.perf_counter() + get_data_execution_time = (get_data_end_time - get_data_start_time) * 1e6 + logger.debug(f"eic get {count} keys data cost %.2f us", get_data_execution_time) + return success_mask + + def generic_batch_set( + self, + keys: List[str], + values: List[torch.Tensor], + ) -> List[bool]: + assert len(keys) == len(values) + logger.debug(f"eic generic set {len(keys)} keys") + if len(keys) == 0: + return True + eic_keys = self._get_eic_key(keys) + keys_vec = eic.StringVector() + vals_vec = eic.IOBuffers() + count = len(keys) + registered = False + items = [] + if self.enable_kv_set_direct: + values_data_ptrs = [] + items = self.kv_cache_write_mem_pool.try_allocate_kv_cache( + self.kv_cache_shape, self.kv_cache_dtype, count + ) + if items is None: + logger.warning("can not allocate tensor from pool") + for i, value in enumerate(values): + values_data_ptrs.append( + (value.data_ptr(), value.element_size() * value.numel(), False) + ) + else: + objs = items + registered = True + for i, key in enumerate(eic_keys): + temp = objs[i].reshape(values[i].shape).contiguous() + temp.copy_(values[i]) + if temp.data_ptr() != objs[i].data_ptr(): + registered = False + temp = temp.cpu() + values_data_ptrs.append( + ( + temp.data_ptr(), + temp.element_size() * temp.numel(), + registered, + ) + ) + + for i, key in enumerate(eic_keys): + keys_vec.append(key) + data_ptr, data_size, registered = values_data_ptrs[i] + vals_vec.append(data_ptr, data_size, registered) + else: + # use tensor direct + for i, key in enumerate(eic_keys): + keys_vec.append(key) + vals_vec.append( + 
values[i].data_ptr(), + values[i].element_size() * values[i].numel(), + False, + ) + + # set options + set_option = eic.SetOption() + set_option.ns = self.eic_namespace + set_option.ttl_second = -1 + status_code, set_outcome = self.connection.mset(keys_vec, vals_vec, set_option) + if status_code != eic.StatusCode.SUCCESS: + logger.error(f"eic mset {len(eic_keys)} failed, status_code {status_code}") + else: + logger.debug(f"eic mset {len(eic_keys)} success") + + if self.enable_kv_set_direct and items is not None: + for item in items: + self.kv_cache_write_mem_pool.free_to_mempool(item.data_ptr()) + + err_code = set_outcome.status_codes[0] + if err_code != eic.StatusCode.SUCCESS: + logger.error(f"set data key {len(eic_keys)} failed, err_code {err_code}") + return [False] * len(keys) + + logger.debug(f"set data key {len(eic_keys)} success") + return [True] * len(keys) + + def generic_batch_get( + self, keys: List[str], buffers: List[torch.Tensor] + ) -> List[bool]: + # all success or all fail + logger.debug(f"eic generic get {len(keys)} keys") + eic_keys = self._get_eic_key(keys) + get_data_start_time = time.perf_counter() + data_keys = eic.StringVector() + data_vals = eic.IOBuffers() + count = len(eic_keys) + registered = False + items = [] + success_mask = [True] * len(keys) + if self.enable_kv_get_direct: + items = self.kv_cache_get_mem_pool.try_allocate_kv_cache( + self.kv_cache_shape, self.kv_cache_dtype, count + ) + if items is None: + logger.warning("can not allocate tensor from pool") + for i, key in enumerate(eic_keys): + data_keys.append(key) + data_vals.append( + buffers[i].data_ptr(), + buffers[i].element_size() * buffers[i].numel(), + False, + ) + else: + registered = True + for i, key in enumerate(eic_keys): + data_keys.append(key) + data_vals.append( + items[i].data_ptr(), + items[i].element_size() * items[i].numel(), + registered, + ) + + else: + for i, key in enumerate(eic_keys): + data_keys.append(key) + data_vals.append( + buffers[i].data_ptr(), + 
buffers[i].element_size() * buffers[i].numel(), + False, + ) + + # Get data: recv data buffer tensor + get_option = eic.GetOption() + get_option.ns = self.eic_namespace + status_code, data_vals, get_outcome = self.connection.mget( + data_keys, get_option, data_vals + ) + + if status_code != eic.StatusCode.SUCCESS: + if status_code == eic.StatusCode.PARTIAL_FAILED: + for i, err_code in enumerate(get_outcome.status_codes): + success = err_code == eic.StatusCode.SUCCESS + if success: + logger.debug(f"eic get data {eic_keys[i]} success") + else: + logger.error( + f"eic get data {eic_keys[i]} failed, err_code {err_code}" + ) + success_mask[i] = False + else: + logger.error( + f"eic mget {len(eic_keys)} keys failed, status_code {status_code}" + ) + success_mask = [False] * len(keys) + + if registered: + for i, item in enumerate(items): + if success_mask[i]: + buffers[i].copy_(item) + self.kv_cache_get_mem_pool.free_to_mempool(item.data_ptr()) + + get_data_end_time = time.perf_counter() + get_data_execution_time = (get_data_end_time - get_data_start_time) * 1e6 + logger.debug(f"eic get {count} keys data cost %.2f us", get_data_execution_time) + return success_mask + + def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]: + new_keys = [] + for k in keys: + new_keys.append(f"{k}_k") + new_keys.append(f"{k}_v") + return new_keys + + def _get_mha_zero_copy_values( + self, values: List[torch.Tensor] + ) -> List[torch.Tensor]: + new_values = [] + for value in values: + new_values.append(value[0]) + new_values.append(value[1]) + return new_values + + def _batch_get_preprocess(self, keys, host_indices): + page_num = len(host_indices) // self.page_size + # use memory pool directly or dummy page + values = ( + [ + self.memory_pool_host.get_data_page( + host_indices[i * self.page_size], flat=False + ) + for i in range(page_num) + ] + if self.use_zero_copy + else [ + self.memory_pool_host.get_dummy_flat_data_page() + for _ in range(page_num) + ] + ) + + if 
self.use_zero_copy and not self.is_mla_model: + keys = self._get_mha_zero_copy_keys(keys) + values = self._get_mha_zero_copy_values(values) + + return keys, values + + def _batch_get_postprocess(self, host_indices, values, results): + page_num = len(host_indices) // self.page_size + + if self.use_zero_copy: + if not self.is_mla_model: + results = [ + (results[2 * i] and results[2 * i + 1]) for i in range(page_num) + ] + results = results[:page_num] + return results + + # dummy page copy to host memory pool + for i in range(page_num): + if not results[i]: + break + self.memory_pool_host.set_from_flat_data_page( + host_indices[i * self.memory_pool_host.page_size], values[i] + ) + + return results + + def batch_get_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + keys, values = self._batch_get_preprocess(keys, host_indices) + results = self.batch_get(keys, values) + return self._batch_get_postprocess(host_indices, values, results) + + def _batch_set_preprocess(self, keys, host_indices): + page_num = len(host_indices) // self.page_size + flat = not self.use_zero_copy + values = [ + self.memory_pool_host.get_data_page( + host_indices[i * self.page_size], flat=flat + ) + for i in range(page_num) + ] + + if self.use_zero_copy and not self.is_mla_model: + keys = self._get_mha_zero_copy_keys(keys) + values = self._get_mha_zero_copy_values(values) + + return keys, values + + def batch_set_v1( + self, + keys: List[str], + host_indices: torch.Tensor, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> List[bool]: + keys, values = self._batch_set_preprocess(keys, host_indices) + results = self.batch_set(keys, values) + return results diff --git a/python/sglang/srt/mem_cache/storage/eic/test_unit.py b/python/sglang/srt/mem_cache/storage/eic/test_unit.py new file mode 100644 index 000000000..03d348ad8 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/eic/test_unit.py @@ -0,0 
+1,115 @@ +import argparse +import os + +import eic +import torch +import yaml + + +def pase_args(): + parser = argparse.ArgumentParser(description="EIC Storage Unit Test") + parser.add_argument( + "--config", + "-c", + type=str, + default="/sgl-workspace/config/remote-eic.yaml", + help="EIC yaml config", + ) + args, _ = parser.parse_known_args() + return args + + +def init_eic_client(): + args = pase_args() + config_path = os.path.abspath(args.config) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + with open(config_path, "r") as fin: + config = yaml.safe_load(fin) + + remote_url = config.get("remote_url", None) + if remote_url is None: + AssertionError("remote_url is None") + endpoint = remote_url[len("eic://") :] + eic_instance_id = config.get("eic_instance_id", None) + eic_log_dir = config.get("eic_log_dir", None) + eic_log_level = config.get("eic_log_level", 2) + eic_trans_type = config.get("eic_trans_type", 3) + eic_flag_file = config.get("eic_flag_file", None) + + if not os.path.exists(eic_log_dir): + os.makedirs(eic_log_dir, exist_ok=True) + eic_client = eic.Client() + init_option = eic.InitOption() + init_option.log_dir = eic_log_dir + init_option.log_level = eic.LogLevel(eic_log_level) + init_option.transport_type = eic.TransportType(eic_trans_type) + init_option.flag_file = eic_flag_file + ret = eic_client.init(eic_instance_id, endpoint, init_option) + if ret != 0: + raise RuntimeError(f"EIC Client init failed with error code: {ret}") + return eic_client + + +def test_set(eic_client): + test_key = ["test_key_" + str(i) for i in range(16)] + tensors = [ + torch.ones([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu") + for _ in range(16) + ] + data_keys = eic.StringVector() + data_vals = eic.IOBuffers() + for i in range(16): + data_keys.append(test_key[i]) + data_vals.append( + tensors[i].data_ptr(), tensors[i].numel() * tensors[i].element_size(), False + ) + set_opt = eic.SetOption() + 
set_opt.ttl_second = 3 + status_code, set_outcome = eic_client.mset(data_keys, data_vals, set_opt) + assert ( + status_code == eic.StatusCode.SUCCESS + ), f"Set failed with status code: {status_code}" + + +def test_get(eic_client): + test_key = ["test_key_" + str(i) for i in range(16)] + tensors = [ + torch.zeros([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu") + for _ in range(16) + ] + data_keys = eic.StringVector() + data_vals = eic.IOBuffers() + for i in range(16): + data_keys.append(test_key[i]) + data_vals.append( + tensors[i].data_ptr(), tensors[i].numel() * tensors[i].element_size(), False + ) + get_opt = eic.GetOption() + status_code, data_vals, get_outcome = eic_client.mget(data_keys, get_opt, data_vals) + assert ( + status_code == eic.StatusCode.SUCCESS + ), f"Get failed with status code: {status_code}" + + +def test_exists(eic_client): + test_key = ["test_key_" + str(i) for i in range(16)] + data_keys = eic.StringVector() + for key in test_key: + data_keys.append(key) + exists_opt = eic.ExistOption() + status_code, exists_outcome = eic_client.mexist(data_keys, exists_opt) + assert ( + status_code == eic.StatusCode.SUCCESS + ), f"Exists failed with status code: {status_code}" + + +def main(): + eic_client = init_eic_client() + test_set(eic_client) + test_exists(eic_client) + test_get(eic_client) + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d647e10ec..d616e7f1a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2260,7 +2260,7 @@ class ServerArgs: parser.add_argument( "--hicache-storage-backend", type=str, - choices=["file", "mooncake", "hf3fs", "nixl", "aibrix", "dynamic"], + choices=["file", "mooncake", "hf3fs", "nixl", "aibrix", "dynamic", "eic"], default=ServerArgs.hicache_storage_backend, help="The storage backend for hierarchical KV cache. " "Built-in backends: file, mooncake, hf3fs, nixl, aibrix. "