[HiCache] Minor fix on file storage backend (#9869)
This commit is contained in:
@@ -648,9 +648,9 @@ class HiCacheController:
|
|||||||
operation.increment(get_result * self.page_size)
|
operation.increment(get_result * self.page_size)
|
||||||
|
|
||||||
def _generic_page_get(self, operation, hash_values, host_indices):
|
def _generic_page_get(self, operation, hash_values, host_indices):
|
||||||
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
|
dummy_page_dst = [
|
||||||
hash_values
|
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
||||||
)
|
]
|
||||||
page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
|
page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
|
||||||
if page_data is None:
|
if page_data is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -136,18 +136,19 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
):
|
):
|
||||||
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
||||||
|
|
||||||
tp_rank, tp_size, is_mla = (
|
tp_rank, tp_size, model_name = (
|
||||||
storage_config.tp_rank,
|
storage_config.tp_rank,
|
||||||
storage_config.tp_size,
|
storage_config.tp_size,
|
||||||
storage_config.is_mla_model,
|
storage_config.model_name,
|
||||||
)
|
)
|
||||||
self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
|
model_name = "-".join(model_name.split("/")) if model_name else ""
|
||||||
|
self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
|
||||||
if not os.path.exists(self.file_path) and tp_rank == 0:
|
if not os.path.exists(self.file_path) and tp_rank == 0:
|
||||||
os.makedirs(self.file_path)
|
os.makedirs(self.file_path)
|
||||||
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
||||||
|
|
||||||
def _get_suffixed_key(self, key: str) -> str:
|
def _get_suffixed_key(self, key: str) -> str:
|
||||||
return key + self.tp_suffix
|
return key + self.config_suffix
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
@@ -158,13 +159,11 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
key = self._get_suffixed_key(key)
|
key = self._get_suffixed_key(key)
|
||||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||||
try:
|
try:
|
||||||
# Load directly into target_location's memory buffer
|
expected = target_location.numel() * target_location.element_size()
|
||||||
with open(tensor_path, "rb") as f:
|
with open(tensor_path, "rb", buffering=0) as f:
|
||||||
target_location.set_(
|
buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
|
||||||
torch.frombuffer(f.read(), dtype=target_location.dtype)
|
if f.readinto(buf) != expected:
|
||||||
.reshape(target_location.shape)
|
raise IOError(f"Short read for {key}")
|
||||||
.untyped_storage()
|
|
||||||
)
|
|
||||||
return target_location
|
return target_location
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
|
logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
|
||||||
|
|||||||
Reference in New Issue
Block a user