feat(hicache-3fs): 3FS-Store Backup Optimizations For MLA Model. (#9692)
This commit is contained in:
@@ -125,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
entries: int,
|
entries: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
metadata_client: Hf3fsMetadataInterface,
|
metadata_client: Hf3fsMetadataInterface,
|
||||||
|
is_mla_model: bool = False,
|
||||||
):
|
):
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
@@ -134,9 +135,13 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.metadata_client = metadata_client
|
self.metadata_client = metadata_client
|
||||||
|
self.is_mla_model = is_mla_model
|
||||||
self.numel = self.bytes_per_page // self.dtype.itemsize
|
self.numel = self.bytes_per_page // self.dtype.itemsize
|
||||||
self.num_pages = self.file_size // self.bytes_per_page
|
self.num_pages = self.file_size // self.bytes_per_page
|
||||||
|
self.skip_backup = False
|
||||||
|
if self.is_mla_model and self.rank != 0:
|
||||||
|
self.skip_backup = True
|
||||||
|
self.rank = 0
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
|
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
|
||||||
@@ -209,10 +214,14 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
||||||
|
|
||||||
# Choose metadata client based on configuration
|
# Choose metadata client based on configuration
|
||||||
|
is_mla_model = False
|
||||||
if "metadata_server_url" in config and config["metadata_server_url"]:
|
if "metadata_server_url" in config and config["metadata_server_url"]:
|
||||||
# Use global metadata client to connect to metadata server
|
# Use global metadata client to connect to metadata server
|
||||||
metadata_server_url = config["metadata_server_url"]
|
metadata_server_url = config["metadata_server_url"]
|
||||||
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
||||||
|
|
||||||
|
# Enable MLA optimization only when using the global metadata client
|
||||||
|
is_mla_model = storage_config.is_mla_model if storage_config else False
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using global metadata client with server url: {metadata_server_url}"
|
f"Using global metadata client with server url: {metadata_server_url}"
|
||||||
)
|
)
|
||||||
@@ -222,13 +231,15 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
|
|
||||||
return HiCacheHF3FS(
|
return HiCacheHF3FS(
|
||||||
rank=rank,
|
rank=rank,
|
||||||
file_path=f"{config['file_path_prefix']}.{rank}.bin",
|
# Let all ranks use the same file path for MLA model
|
||||||
|
file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
|
||||||
file_size=int(config["file_size"]),
|
file_size=int(config["file_size"]),
|
||||||
numjobs=int(config["numjobs"]),
|
numjobs=int(config["numjobs"]),
|
||||||
bytes_per_page=bytes_per_page,
|
bytes_per_page=bytes_per_page,
|
||||||
entries=int(config["entries"]),
|
entries=int(config["entries"]),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
metadata_client=metadata_client,
|
metadata_client=metadata_client,
|
||||||
|
is_mla_model=is_mla_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
@@ -312,6 +323,10 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
target_locations: Optional[Any] = None,
|
target_locations: Optional[Any] = None,
|
||||||
target_sizes: Optional[Any] = None,
|
target_sizes: Optional[Any] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
# In MLA backend, only one rank needs to backup the KV cache
|
||||||
|
if self.skip_backup:
|
||||||
|
return True
|
||||||
|
|
||||||
# Todo: Add prefix block's hash key
|
# Todo: Add prefix block's hash key
|
||||||
key_with_prefix = [(key, "") for key in keys]
|
key_with_prefix = [(key, "") for key in keys]
|
||||||
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
||||||
@@ -363,16 +378,21 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
|
|
||||||
return all(results)
|
return all(results)
|
||||||
|
|
||||||
@synchronized()
|
|
||||||
def delete(self, key: str) -> None:
|
def delete(self, key: str) -> None:
|
||||||
self.metadata_client.delete_keys(self.rank, [key])
|
self.metadata_client.delete_keys(self.rank, [key])
|
||||||
|
|
||||||
@synchronized()
|
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, key: str) -> bool:
|
||||||
result = self.metadata_client.exists(self.rank, [key])
|
result = self.metadata_client.exists(self.rank, [key])
|
||||||
return result[0] if result else False
|
return result[0] if result else False
|
||||||
|
|
||||||
@synchronized()
|
def batch_exists(self, keys: List[str]) -> int:
|
||||||
|
results = self.metadata_client.exists(self.rank, keys)
|
||||||
|
for i in range(len(keys)):
|
||||||
|
if not results[i]:
|
||||||
|
return i
|
||||||
|
|
||||||
|
return len(keys)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self.metadata_client.clear(self.rank)
|
self.metadata_client.clear(self.rank)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user