[HiCacheStorage] backup optimization for MLA model (#8865)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
huangtingwei
2025-08-22 18:03:51 +08:00
committed by GitHub
parent 70cf4abccc
commit 6078d5fcc0
4 changed files with 39 additions and 20 deletions

View File

@@ -26,6 +26,8 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
logger = logging.getLogger(__name__)
@@ -238,13 +240,14 @@ class HiCacheController:
self.io_backend = io_backend
self.enable_storage = False
self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
# todo: move backend initialization to storage backend module
if storage_backend is not None:
self.storage_backend_type = storage_backend
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.storage_backend = HiCacheFile(is_mla=self.is_mla)
self.get_hash_str = get_hash_str
elif storage_backend == "nixl":
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
@@ -257,12 +260,11 @@ class HiCacheController:
get_hash_str_mooncake,
)
self.storage_backend = MooncakeStore()
self.storage_backend = MooncakeStore(is_mla=self.is_mla)
self.get_hash_str = get_hash_str_mooncake
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
assert self.mem_pool_host.layout == "page_first"
elif storage_backend == "hf3fs":
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
HiCacheHF3FS,
)
@@ -399,6 +401,15 @@ class HiCacheController:
self.prefetch_thread.start()
self.backup_thread.start()
@property
def backup_skip(self):
    """Whether this rank should skip backing up KV data to the storage backend.

    True only when all of the following hold:
      * the host memory pool is an MLA pool (``self.is_mla``),
      * this process is not tensor-parallel rank 0, and
      * the active backend is "file" or "mooncake".

    NOTE(review): presumably MLA KV data is replicated across tensor-parallel
    ranks, so rank 0 alone can perform the backup — confirm against the model
    runner before extending this to other backends.
    """
    return (
        self.is_mla
        and get_tensor_model_parallel_rank() != 0
        # TODO: the rank-0-only optimization is currently applied only to the
        # "file" and "mooncake" backends; other backends always back up.
        and self.storage_backend_type in ["file", "mooncake"]
    )
def write(
self,
device_indices: torch.Tensor,
@@ -809,17 +820,20 @@ class HiCacheController:
if operation is None:
continue
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_backup(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_backup(operation, batch_size=128)
if not self.backup_skip:
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_backup(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_backup(operation, batch_size=128)
else:
self.generic_page_backup(operation)
min_completed_tokens = operation.completed_tokens
else:
self.generic_page_backup(operation)
min_completed_tokens = len(operation.token_ids)
min_completed_tokens = operation.completed_tokens
if self.tp_world_size > 1:
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int