diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index ca441b9f6..8acce8ac7 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -648,9 +648,9 @@ class HiCacheController: operation.increment(get_result * self.page_size) def _generic_page_get(self, operation, hash_values, host_indices): - dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len( - hash_values - ) + dummy_page_dst = [ + self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values + ] page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst) if page_data is None: return diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 1d3ed5ae9..2487910e1 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -136,18 +136,19 @@ class HiCacheFile(HiCacheStorage): ): self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path) - tp_rank, tp_size, is_mla = ( + tp_rank, tp_size, model_name = ( storage_config.tp_rank, storage_config.tp_size, - storage_config.is_mla_model, + storage_config.model_name, ) - self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else "" + model_name = "-".join(model_name.split("/")) if model_name else "" + self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}" if not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) logger.info(f"Created HiCacheFile storage directory at {self.file_path}") def _get_suffixed_key(self, key: str) -> str: - return key + self.tp_suffix + return key + self.config_suffix def get( self, @@ -158,13 +159,11 @@ class HiCacheFile(HiCacheStorage): key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") try: - # Load directly into target_location's memory buffer - with open(tensor_path, "rb") as f: - target_location.set_( - torch.frombuffer(f.read(), dtype=target_location.dtype) - .reshape(target_location.shape) - .untyped_storage() - ) + expected = target_location.numel() * target_location.element_size() + with open(tensor_path, "rb", buffering=0) as f: + buf = memoryview(target_location.view(torch.uint8).contiguous().numpy()) + if f.readinto(buf) != expected: + raise IOError(f"Short read for {key}") return target_location except FileNotFoundError: logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")