HiCache Storage TP Refinement (#8307)
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
@@ -219,6 +219,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
@@ -244,11 +245,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
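A hedged sketch of the group setup above (the helper name is illustrative, not the project's API): a separate gloo group over the same TP ranks lets the controller run small CPU-side all-reduces for storage bookkeeping without touching the group used for model tensors. It assumes torch.distributed is already initialized.

import torch.distributed as dist

def make_storage_sync_group(tp_group):
    # With a single worker there is nothing to synchronize across.
    if dist.get_world_size(group=tp_group) <= 1:
        return tp_group
    # Mirror the TP ranks in a CPU-friendly gloo group for storage bookkeeping.
    group_ranks = dist.get_process_group_ranks(tp_group)
    return dist.new_group(group_ranks, backend="gloo")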
@@ -568,13 +575,32 @@ class HiCacheController:
                     else:
                         break

+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
+
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                else:
-                    operation.hash_value = hash_value
                     logger.debug(
-                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
+                else:
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
+                    logger.debug(
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                     )
                 self.prefetch_buffer.put(operation)

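A minimal sketch of the agreement step above, assuming torch.distributed is already initialized and tp_group is the gloo group created earlier: every TP worker reduces its local storage hit count to the group minimum, so all ranks revoke or trim the prefetch identically. The helper name is illustrative.

import torch
import torch.distributed as dist

def sync_min_hits(local_hits: int, tp_group, tp_world_size: int) -> int:
    # Single worker: nothing to agree on.
    if tp_world_size <= 1:
        return local_hits
    # A CPU tensor is fine for a gloo group.
    hits = torch.tensor(local_hits, dtype=torch.int)
    dist.all_reduce(hits, op=dist.ReduceOp.MIN, group=tp_group)
    return int(hits.item())

The agreed count is then cut to whole pages (hash_value[: hits // page_size]) and the host indices for the unused tail are freed, as in the hunk above.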
@@ -611,17 +637,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-                    # todo, handle failures in storage backend
-                    self.storage_backend.set(
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                     operation.completed_tokens += self.page_size
                     operation.hash_value.append(last_hash)

-                self.ack_backup_queue.put((operation.id, operation.hash_value))
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )

             except Empty:
                 continue
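A hypothetical sketch of the backup accounting above (write_page, sync_min, and the queue payload are stand-ins, not the project's API): pages are written until the first failure, then the TP group agrees on the smallest completed-token count, and only whole pages up to that point are acknowledged.

from queue import Queue
from typing import Callable, List

def backup_pages(
    page_hashes: List[str],
    page_size: int,
    write_page: Callable[[str], bool],
    sync_min: Callable[[int], int],
    ack_queue: Queue,
) -> None:
    completed_tokens = 0
    written: List[str] = []
    for h in page_hashes:
        if not write_page(h):  # stop at the first failed write
            break
        written.append(h)
        completed_tokens += page_size
    # All TP workers must acknowledge the same prefix of persisted pages.
    min_completed = sync_min(completed_tokens)
    ack_queue.put((written[: min_completed // page_size], min_completed))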
@@ -9,6 +9,12 @@ import torch

 logger = logging.getLogger(__name__)

+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
+
 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()

@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):

     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-        if not os.path.exists(self.file_path):
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
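A small sketch (hypothetical helper mirroring the suffix logic above) of why per-rank suffixes matter: with tensor parallelism each worker holds a different shard of the KV data, so the same logical key must map to a different file per rank when the workers share one storage directory.

def suffixed_key(key: str, tp_rank: int, tp_size: int) -> str:
    # A single worker keeps the plain key; TP workers get a rank/size suffix.
    return f"{key}_{tp_rank}_{tp_size}" if tp_size > 1 else key

assert suffixed_key("abc123", 0, 4) == "abc123_0_4"
assert suffixed_key("abc123", 1, 4) == "abc123_1_4"
assert suffixed_key("abc123", 0, 1) == "abc123"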
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]

     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True

     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)

     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
         # todo: customizable storage prefetch threshold
         self.prefetch_threshold = 256
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -372,7 +374,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -380,9 +382,15 @@ class HiRadixCache(RadixCache):
                 group=self.tp_group,
             )
         for _ in range(queue_size.item()):
-            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
-            self.ongoing_backup[ack_id].hash_value = hash_value
-            self.ongoing_backup[ack_id].release_host()
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # backup is only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
             del self.ongoing_backup[ack_id]

     def check_prefetch_progress(self, req_id: str):
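A purely illustrative sketch (plain dataclasses, not HiRadixCache internals) of the partial-backup handling above: when only a prefix of a node's tokens was persisted, the node is split at that boundary so the hash values are attached to the persisted prefix only.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Node:
    key: List[int]                        # tokens owned by this node
    hash_value: Optional[List[str]] = None
    parent: Optional["Node"] = None

def split_at(node: Node, prefix_len: int) -> Node:
    """Split node so its first prefix_len tokens become a new parent node."""
    prefix = Node(key=node.key[:prefix_len], parent=node.parent)
    node.key = node.key[prefix_len:]
    node.parent = prefix
    return prefix

# Usage: attach hashes only to the persisted prefix (page size 64 assumed here).
node = Node(key=list(range(512)))
completed_tokens, hashes = 256, ["h0", "h1", "h2", "h3"]
if completed_tokens < len(node.key):
    prefix = split_at(node, completed_tokens)
    prefix.hash_value = hashes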
@@ -400,15 +408,18 @@ class HiRadixCache(RadixCache):
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

-        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
+            completed_tokens_tensor = torch.tensor(
+                min_completed_tokens, dtype=torch.int
+            )
             torch.distributed.all_reduce(
-                min_completed_tokens,
+                completed_tokens_tensor,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-            min_completed_tokens = min_completed_tokens.item()
+            min_completed_tokens = completed_tokens_tensor.item()
         fetched_token_ids = token_ids[:min_completed_tokens]
         written_indices = host_indices[:min_completed_tokens]
         matched_length = self._insert_helper_host(
@@ -465,16 +476,19 @@ class HiRadixCache(RadixCache):
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
     ):
-        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+        # align the number of fetching tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
             return

         last_host_node.protect_host()
-        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
-            self.evict_host(len(new_input_tokens))
-            host_indices = self.cache_controller.mem_pool_host.alloc(
-                len(new_input_tokens)
-            )
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
             last_host_node.release_host()
             # no sufficient host memory to prefetch
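A minimal sketch of the page alignment introduced above: the prefetch length is rounded down to a whole number of pages before the threshold check, so partial pages are never requested from storage.

def align_to_page(num_tokens: int, page_size: int) -> int:
    # Round the token count down to a multiple of the page size.
    return num_tokens - (num_tokens % page_size)

assert align_to_page(1000, 64) == 960
assert align_to_page(63, 64) == 0  # less than one page: nothing to prefetch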
@@ -126,6 +126,9 @@ class HostKVCache(abc.ABC):

     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
         if need_size > self.available_size():
             return None
