HiCache storage, style change and bug fix (#8719)
This commit is contained in:
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
|
|||||||
It abstracts the underlying storage mechanism, allowing different implementations to be used.
|
It abstracts the underlying storage mechanism, allowing different implementations to be used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# todo, translate tensor object access for different TP ranks
|
# todo, potentially pass model and TP configs into storage backend
|
||||||
# potentially pass model and TP configs into storage backend
|
|
||||||
# todo, the page size of storage backend does not have to be the same as the host memory pool
|
# todo, the page size of storage backend does not have to be the same as the host memory pool
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
target_location: Optional[Any] = None,
|
target_location: torch.Tensor,
|
||||||
target_sizes: Optional[Any] = None,
|
target_sizes: Optional[Any] = None,
|
||||||
) -> torch.Tensor | None:
|
) -> torch.Tensor | None:
|
||||||
key = self._get_suffixed_key(key)
|
key = self._get_suffixed_key(key)
|
||||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||||
try:
|
try:
|
||||||
if target_location is not None:
|
|
||||||
# Load directly into target_location's memory buffer
|
# Load directly into target_location's memory buffer
|
||||||
with open(tensor_path, "rb") as f:
|
with open(tensor_path, "rb") as f:
|
||||||
target_location.set_(
|
target_location.set_(
|
||||||
torch.frombuffer(f.read(), dtype=target_location.dtype)
|
torch.frombuffer(f.read(), dtype=target_location.dtype)
|
||||||
.reshape(target_location.shape)
|
.reshape(target_location.shape)
|
||||||
.storage()
|
.untyped_storage()
|
||||||
)
|
)
|
||||||
return target_location
|
return target_location
|
||||||
else:
|
|
||||||
loaded_tensor = torch.load(tensor_path)
|
|
||||||
if isinstance(loaded_tensor, torch.Tensor):
|
|
||||||
return loaded_tensor
|
|
||||||
else:
|
|
||||||
logger.error(f"Loaded data for key {key} is not a tensor.")
|
|
||||||
return None
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def batch_get(
|
def batch_get(
|
||||||
self,
|
self,
|
||||||
keys: List[str],
|
keys: List[str],
|
||||||
target_locations: Optional[Any] = None,
|
target_locations: List[torch.Tensor],
|
||||||
target_sizes: Optional[Any] = None,
|
target_sizes: Optional[Any] = None,
|
||||||
) -> List[torch.Tensor | None]:
|
) -> List[torch.Tensor | None]:
|
||||||
return [
|
return [
|
||||||
@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
logger.debug(f"Key {key} already exists. Skipped.")
|
logger.debug(f"Key {key} already exists. Skipped.")
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
torch.save(value, tensor_path)
|
value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save tensor {key}: {e}")
|
logger.error(f"Failed to save tensor {key}: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user