[HiCache] Minor fix on file storage backend (#9869)
@@ -648,9 +648,9 @@ class HiCacheController:
         operation.increment(get_result * self.page_size)
 
     def _generic_page_get(self, operation, hash_values, host_indices):
-        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-            hash_values
-        )
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
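The first hunk fixes Python list aliasing: `[page] * len(hash_values)` builds a list of N references to a single object, so every `batch_get` destination would point at the same buffer, while the comprehension calls `get_dummy_flat_data_page()` once per hash value. A minimal standalone sketch of the difference, using plain tensors rather than the actual memory pool, and assuming `get_dummy_flat_data_page()` allocates a fresh buffer on each call:

import torch

page = torch.zeros(4)
aliased = [page] * 3                           # three references to ONE tensor
distinct = [torch.zeros(4) for _ in range(3)]  # three separate buffers

aliased[0][0] = 1.0
print(aliased[1][0].item())   # 1.0 -- the write is visible through every alias
print(distinct[1][0].item())  # 0.0 -- each destination is independent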
@@ -136,18 +136,19 @@ class HiCacheFile(HiCacheStorage):
     ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
 
-        tp_rank, tp_size, is_mla = (
+        tp_rank, tp_size, model_name = (
             storage_config.tp_rank,
             storage_config.tp_size,
-            storage_config.is_mla_model,
+            storage_config.model_name,
         )
-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
+        model_name = "-".join(model_name.split("/")) if model_name else ""
+        self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
 
     def _get_suffixed_key(self, key: str) -> str:
-        return key + self.tp_suffix
+        return key + self.config_suffix
 
     def get(
         self,
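The second hunk widens the per-file suffix from TP rank/size (previously skipped for MLA models) to a config suffix that always embeds a path-safe model name, so caches for different models can share one storage directory without key collisions. A before/after sketch of the naming; the model name and TP values here are hypothetical:

model_name, tp_rank, tp_size = "org/model-7b", 0, 2  # hypothetical values

# Before: rank/size suffix only when tp_size > 1 (and not an MLA model).
tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""

# After: slashes become dashes and the model name is always included.
safe_name = "-".join(model_name.split("/")) if model_name else ""
config_suffix = f"_{safe_name}_{tp_rank}_{tp_size}"

print("abc123" + tp_suffix)      # abc123_0_2
print("abc123" + config_suffix)  # abc123_org-model-7b_0_2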
@@ -158,13 +159,11 @@ class HiCacheFile(HiCacheStorage):
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            # Load directly into target_location's memory buffer
-            with open(tensor_path, "rb") as f:
-                target_location.set_(
-                    torch.frombuffer(f.read(), dtype=target_location.dtype)
-                    .reshape(target_location.shape)
-                    .untyped_storage()
-                )
+            expected = target_location.numel() * target_location.element_size()
+            with open(tensor_path, "rb", buffering=0) as f:
+                buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
+                if f.readinto(buf) != expected:
+                    raise IOError(f"Short read for {key}")
             return target_location
         except FileNotFoundError:
             logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
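The last hunk replaces `f.read()` + `torch.frombuffer(...).untyped_storage()` + `Tensor.set_()`, which allocated a bytes object and then re-pointed `target_location` at new storage, with `readinto()` on an unbuffered file, which fills the caller's existing buffer in place and fails loudly on a short read. Filling in place likely matters when the destination is a pre-allocated host pool page. A standalone sketch of the `readinto` pattern; the file path and tensor shape are invented for illustration:

import torch

t = torch.empty(8, dtype=torch.float32)   # stand-in for target_location
expected = t.numel() * t.element_size()

# Write known bytes first so the read below has something to fill.
with open("/tmp/page.bin", "wb") as f:
    f.write(torch.arange(8, dtype=torch.float32).numpy().tobytes())

with open("/tmp/page.bin", "rb", buffering=0) as f:
    buf = memoryview(t.view(torch.uint8).numpy())  # view over t's own memory
    if f.readinto(buf) != expected:
        raise IOError("Short read")

print(t)  # tensor([0., 1., ..., 7.]) -- bytes landed directly in t's buffer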