[HiCacheStorage] backup optimization for MLA model (#8865)
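For MLA models the host KV cache is identical across tensor-parallel ranks, so every rank was uploading a duplicate copy to the storage backend. With this change only TP rank 0 performs the backup for the file and mooncake backends (via the new backup_skip property), MLA storage keys drop their rank suffix so all ranks share a single stored copy, and MHA page keys now carry the rank explicitly; ranks that skip the upload still report the full token count as completed.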
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
 
 logger = logging.getLogger(__name__)
 
@@ -238,13 +240,14 @@ class HiCacheController:
         self.io_backend = io_backend
 
         self.enable_storage = False
+        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
-                self.storage_backend = HiCacheFile()
+                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
@@ -257,12 +260,11 @@ class HiCacheController:
                     get_hash_str_mooncake,
                 )
 
-                self.storage_backend = MooncakeStore()
+                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
-                from sglang.srt.distributed import get_tensor_model_parallel_rank
                 from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                     HiCacheHF3FS,
                 )
@@ -399,6 +401,15 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()
 
+    @property
+    def backup_skip(self):
+        return (
+            self.is_mla
+            and get_tensor_model_parallel_rank() != 0
+            # todo: only support file and mooncake
+            and self.storage_backend_type in ["file", "mooncake"]
+        )
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -809,17 +820,20 @@ class HiCacheController:
             if operation is None:
                 continue
 
-            if self.is_mooncake_backend():
-                self.mooncake_page_backup(operation)
-            elif self.storage_backend_type == "hf3fs":
-                if self.mem_pool_host.layout == "page_first":
-                    self.zerocopy_page_backup(operation, batch_size=128)
-                elif self.mem_pool_host.layout == "layer_first":
-                    self.generic_page_backup(operation, batch_size=128)
+            if not self.backup_skip:
+                if self.is_mooncake_backend():
+                    self.mooncake_page_backup(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    if self.mem_pool_host.layout == "page_first":
+                        self.zerocopy_page_backup(operation, batch_size=128)
+                    elif self.mem_pool_host.layout == "layer_first":
+                        self.generic_page_backup(operation, batch_size=128)
+                else:
+                    self.generic_page_backup(operation)
+                min_completed_tokens = operation.completed_tokens
             else:
-                self.generic_page_backup(operation)
+                min_completed_tokens = len(operation.token_ids)
 
-            min_completed_tokens = operation.completed_tokens
             if self.tp_world_size > 1:
                 completed_tokens_tensor = torch.tensor(
                     min_completed_tokens, dtype=torch.int
@@ -101,11 +101,11 @@ class HiCacheStorage(ABC):
 
 class HiCacheFile(HiCacheStorage):
 
-    def __init__(self, file_path: str = "/tmp/hicache"):
+    def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
@@ -7,6 +7,7 @@ from functools import wraps
 import psutil
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -487,8 +488,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-            key_list.append(f"{key_}_v")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
@@ -19,14 +19,13 @@ logger = logging.getLogger(__name__)
 
 
 def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
     if prior_hash:
         prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
     current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
+    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
 
 
 @dataclass
@@ -97,7 +96,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
-    def __init__(self):
+    def __init__(self, is_mla: bool = False):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -127,6 +126,7 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
+            self.is_mla = is_mla
 
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -223,11 +223,15 @@ class MooncakeStore(HiCacheStorage):
 
     def exists(self, keys) -> bool | dict:
         _keys = []
+        local_rank = get_tensor_model_parallel_rank()
         for key in keys:
             if key is None:
                 return None
 
-            _keys.append(f"{key}_k")
+            if self.is_mla:
+                _keys.append(f"{key}_k")
+            else:
+                _keys.append(f"{key}_{local_rank}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result
 
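Below is a minimal standalone sketch, separate from the sglang codebase, of the behavior the diff introduces. The helper names are hypothetical; only the key formats and the skip condition mirror the changes above, and the mooncake-style per-page keys are shown (the file backend uses a per-rank file suffix instead).

# Standalone illustration of the key layout and backup gate after this change.
# storage_keys() and backup_skip() are hypothetical helpers, not sglang APIs.

def storage_keys(page_hash: str, tp_rank: int, is_mla: bool) -> list:
    """Mooncake-style keys written to the storage backend for one KV page."""
    if is_mla:
        # MLA: the latent KV cache is identical on every TP rank, so a single
        # rank-agnostic key lets all ranks share one stored copy.
        return [f"{page_hash}_k"]
    # MHA: each TP rank holds a different head shard, so the rank is part of
    # the key and every rank stores its own copy.
    return [f"{page_hash}_{tp_rank}_k", f"{page_hash}_{tp_rank}_v"]

def backup_skip(is_mla: bool, tp_rank: int, backend: str) -> bool:
    """Mirrors HiCacheController.backup_skip: for MLA, only rank 0 uploads."""
    return is_mla and tp_rank != 0 and backend in ["file", "mooncake"]

if __name__ == "__main__":
    print(storage_keys("abc123", tp_rank=1, is_mla=True))    # ['abc123_k']
    print(storage_keys("abc123", tp_rank=1, is_mla=False))   # ['abc123_1_k', 'abc123_1_v']
    print(backup_skip(is_mla=True, tp_rank=1, backend="mooncake"))   # True
    print(backup_skip(is_mla=False, tp_rank=1, backend="mooncake"))  # False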