feat(hicache-3fs): 3FS-Store Backup Optimizations For MLA Model. (#9692)
@@ -125,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
        entries: int,
        dtype: torch.dtype,
        metadata_client: Hf3fsMetadataInterface,
        is_mla_model: bool = False,
    ):
        self.rank = rank
        self.file_path = file_path
@@ -134,9 +135,13 @@ class HiCacheHF3FS(HiCacheStorage):
        self.entries = entries
        self.dtype = dtype
        self.metadata_client = metadata_client

        self.is_mla_model = is_mla_model
        self.numel = self.bytes_per_page // self.dtype.itemsize
        self.num_pages = self.file_size // self.bytes_per_page
        self.skip_backup = False
        if self.is_mla_model and self.rank != 0:
            self.skip_backup = True
            self.rank = 0

        logger.info(
            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
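The constructor change above is the core of the optimization: with MLA the latent KV cache is replicated rather than sharded across tensor-parallel ranks, so every rank would otherwise persist identical pages. Instead, non-zero ranks mark themselves skip_backup and remap onto rank 0's file and metadata namespace so reads still resolve. A minimal sketch of that rule, written as a hypothetical standalone helper rather than the class itself:

def resolve_backup_role(rank: int, is_mla_model: bool) -> tuple:
    # Mirrors the __init__ logic above: non-zero MLA ranks skip writes and
    # reuse rank 0's file / metadata namespace for reads.
    skip_backup = False
    if is_mla_model and rank != 0:
        skip_backup = True
        rank = 0
    return rank, skip_backup

# Hypothetical usage, for illustration only:
assert resolve_backup_role(3, is_mla_model=True) == (0, True)
assert resolve_backup_role(3, is_mla_model=False) == (3, False)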
@@ -209,10 +214,14 @@ class HiCacheHF3FS(HiCacheStorage):
            raise ValueError(f"Missing required keys in config: {missing_keys}")

        # Choose metadata client based on configuration
        is_mla_model = False
        if "metadata_server_url" in config and config["metadata_server_url"]:
            # Use global metadata client to connect to metadata server
            metadata_server_url = config["metadata_server_url"]
            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)

            # Enable MLA optimization only when using the global metadata client
            is_mla_model = storage_config.is_mla_model if storage_config else False
            logger.info(
                f"Using global metadata client with server url: {metadata_server_url}"
            )
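The branch above only honors the MLA flag when a metadata server URL is configured, since pages shared by all ranks need metadata that every rank can see. A sketch of that decision, with the fallback client passed in as a parameter because the else-branch is cut off by this hunk; apart from Hf3fsGlobalMetadataClient, the names here are illustrative:

def choose_metadata_client(config, storage_config, global_client_cls, local_client_cls):
    # MLA sharing is only enabled together with the global metadata client,
    # mirroring the hunk above; the local fallback is an assumed placeholder.
    is_mla_model = False
    if config.get("metadata_server_url"):
        client = global_client_cls(config["metadata_server_url"])
        is_mla_model = storage_config.is_mla_model if storage_config else False
    else:
        client = local_client_cls()
    return client, is_mla_model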
@@ -222,13 +231,15 @@ class HiCacheHF3FS(HiCacheStorage):

        return HiCacheHF3FS(
            rank=rank,
            file_path=f"{config['file_path_prefix']}.{rank}.bin",
            # Let all ranks use the same file path for MLA model
            file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
            file_size=int(config["file_size"]),
            numjobs=int(config["numjobs"]),
            bytes_per_page=bytes_per_page,
            entries=int(config["entries"]),
            dtype=dtype,
            metadata_client=metadata_client,
            is_mla_model=is_mla_model,
        )

    def get(
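The behavioural change in the factory above is the file path: for an MLA model every rank opens rank 0's backing file instead of its own. A sketch of just that rule, as a hypothetical helper with an illustrative prefix:

def backing_file_path(file_path_prefix: str, rank: int, is_mla_model: bool) -> str:
    # MLA collapses all ranks onto the rank-0 file; otherwise one file per rank.
    return f"{file_path_prefix}.{rank if not is_mla_model else 0}.bin"

assert backing_file_path("/data/hicache", 2, is_mla_model=True) == "/data/hicache.0.bin"
assert backing_file_path("/data/hicache", 2, is_mla_model=False) == "/data/hicache.2.bin"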
@@ -312,6 +323,10 @@ class HiCacheHF3FS(HiCacheStorage):
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        # In MLA backend, only one rank needs to backup the KV cache
        if self.skip_backup:
            return True

        # Todo: Add prefix block's hash key
        key_with_prefix = [(key, "") for key in keys]
        indices = self.metadata_client.reserve_and_allocate_page_indices(
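The guard above turns the write path into a no-op on non-zero MLA ranks while still reporting success, so callers do not retry or treat the skipped backup as a failure; rank 0 performs the single real write. A trimmed-down sketch of that contract, with hypothetical names standing in for the page reservation and I/O that the real method goes on to do:

def batch_set_sketch(storage, keys, values) -> bool:
    if storage.skip_backup:
        # Non-zero MLA ranks: the shared file is owned by rank 0, so report
        # success without touching storage or metadata.
        return True
    key_with_prefix = [(key, "") for key in keys]  # prefix-block hash still a Todo above
    return storage._write_pages(key_with_prefix, values)  # placeholder for the real I/O path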
@@ -363,16 +378,21 @@ class HiCacheHF3FS(HiCacheStorage):

        return all(results)

    @synchronized()
    def delete(self, key: str) -> None:
        self.metadata_client.delete_keys(self.rank, [key])

    @synchronized()
    def exists(self, key: str) -> bool:
        result = self.metadata_client.exists(self.rank, [key])
        return result[0] if result else False

    @synchronized()
    def batch_exists(self, keys: List[str]) -> int:
        results = self.metadata_client.exists(self.rank, keys)
        for i in range(len(keys)):
            if not results[i]:
                return i

        return len(keys)

    def clear(self) -> None:
        self.metadata_client.clear(self.rank)
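Note that batch_exists above returns a prefix length, not a per-key mask: it reports how many keys from the front of the list are already stored and stops at the first miss, presumably because prefix-cache lookups only care about the contiguous leading run of hits. A standalone illustration of that contract, with a plain set standing in for the metadata client:

def count_existing_prefix(keys, stored) -> int:
    # Same contract as batch_exists: length of the leading run of hits.
    for i, key in enumerate(keys):
        if key not in stored:
            return i
    return len(keys)

assert count_existing_prefix(["a", "b", "c"], {"a", "b"}) == 2
assert count_existing_prefix(["x", "y"], {"y"}) == 0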