HiCache storage, style change and bug fix (#8719)
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """

-    # todo, translate tensor object access for different TP ranks
-    # potentially pass model and TP configs into storage backend
+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool

     @abstractmethod
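For context on the interface these todo notes sit on, here is a minimal, hypothetical in-memory stand-in that mirrors the get/set shape HiCacheStorage implies in this diff. The class name, the dict backend, and the set() signature are illustrative assumptions, not part of this commit:

from typing import Any, Dict, Optional

import torch


class DictStorageSketch:
    """Hypothetical stand-in mirroring the HiCacheStorage get/set shape."""

    def __init__(self) -> None:
        self._pages: Dict[str, bytes] = {}  # key -> raw tensor bytes

    def set(self, key: str, value: torch.Tensor) -> bool:
        # Store raw bytes, the same format the fixed HiCacheFile.set writes to disk.
        self._pages[key] = value.contiguous().view(dtype=torch.uint8).numpy().tobytes()
        return True

    def get(
        self,
        key: str,
        target_location: torch.Tensor,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        blob = self._pages.get(key)
        if blob is None:
            return None
        # copy_ (rather than set_) keeps this sketch shape- and dtype-safe.
        target_location.copy_(
            torch.frombuffer(bytearray(blob), dtype=target_location.dtype).reshape(
                target_location.shape
            )
        )
        return target_location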
@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
     def get(
         self,
         key: str,
-        target_location: Optional[Any] = None,
+        target_location: torch.Tensor,
         target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            if target_location is not None:
-                # Load directly into target_location's memory buffer
-                with open(tensor_path, "rb") as f:
-                    target_location.set_(
-                        torch.frombuffer(f.read(), dtype=target_location.dtype)
-                        .reshape(target_location.shape)
-                        .storage()
-                    )
-                return target_location
-            else:
-                loaded_tensor = torch.load(tensor_path)
-                if isinstance(loaded_tensor, torch.Tensor):
-                    return loaded_tensor
-                else:
-                    logger.error(f"Loaded data for key {key} is not a tensor.")
-                    return None
+            # Load directly into target_location's memory buffer
+            with open(tensor_path, "rb") as f:
+                target_location.set_(
+                    torch.frombuffer(f.read(), dtype=target_location.dtype)
+                    .reshape(target_location.shape)
+                    .untyped_storage()
+                )
+            return target_location
         except FileNotFoundError:
             logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
             return None

     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[Any] = None,
+        target_locations: List[torch.Tensor],
         target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
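The rewritten get() drops the torch.load fallback, makes target_location mandatory, and swaps .storage() for .untyped_storage() (the TypedStorage accessor is deprecated in recent PyTorch). A minimal sketch of the set_-into-preallocated-buffer mechanics the new code relies on; the bytearray is an assumption to guarantee a writable buffer (torch.frombuffer warns on read-only bytes such as f.read() output):

import torch

payload = torch.arange(6, dtype=torch.float32).reshape(2, 3)
raw = bytearray(payload.contiguous().view(dtype=torch.uint8).numpy().tobytes())

# Rebind a preallocated tensor to storage built over the raw bytes,
# mirroring the frombuffer -> reshape -> untyped_storage -> set_ chain above.
target = torch.empty(2, 3, dtype=torch.float32)
target.set_(
    torch.frombuffer(raw, dtype=target.dtype)
    .reshape(target.shape)
    .untyped_storage()
)
assert torch.equal(target.view(payload.shape), payload)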
@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
         try:
-            torch.save(value, tensor_path)
+            value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
         except Exception as e:
             logger.error(f"Failed to save tensor {key}: {e}")
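This set() change is the actual bug fix: torch.save() wraps the tensor in a pickled zip container, which the raw-byte frombuffer path in get() cannot decode, so writing the bytes with numpy's tofile() makes the two ends agree. A small self-contained check of that mismatch and of the fixed round trip (buffer-based here to avoid temp files; variable names are illustrative):

import io

import torch

value = torch.randn(4, 8, dtype=torch.float32)

# Old format: a pickled zip container, larger than the payload itself.
buf = io.BytesIO()
torch.save(value, buf)
pickled = buf.getvalue()

# New format: exactly numel * element_size raw bytes.
raw = value.contiguous().view(dtype=torch.uint8).numpy().tobytes()
assert len(raw) == value.numel() * value.element_size()
assert len(pickled) > len(raw)  # container overhead frombuffer would misread

# get()'s frombuffer path round-trips the raw format exactly.
restored = torch.frombuffer(bytearray(raw), dtype=value.dtype).reshape(value.shape)
assert torch.equal(restored, value)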