Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -250,17 +251,33 @@ class HiCacheController:
|
||||
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
|
||||
if self.tp_world_size > 1:
|
||||
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
|
||||
self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
self.prefetch_tp_group = torch.distributed.new_group(
|
||||
group_ranks, backend="gloo"
|
||||
)
|
||||
self.backup_tp_group = torch.distributed.new_group(
|
||||
group_ranks, backend="gloo"
|
||||
)
|
||||
|
||||
if storage_backend == "file":
|
||||
self.storage_backend = HiCacheFile()
|
||||
self.enable_storage = True
|
||||
# todo: threshold policy for prefetching
|
||||
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
||||
elif storage_backend == "hf3fs":
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
dtype = mem_pool_host.dtype
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
rank, bytes_per_page, dtype
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported storage backend: {storage_backend}"
|
||||
)
|
||||
self.enable_storage = True
|
||||
# todo: threshold policy for prefetching
|
||||
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
||||
|
||||
self.load_cache_event = load_cache_event
|
||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||
@@ -522,8 +539,8 @@ class HiCacheController:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||
for h in operation.hash_value:
|
||||
page_data = self.storage_backend.get(h)
|
||||
page_datas = self.storage_backend.batch_get(operation.hash_value)
|
||||
for h, page_data in zip(operation.hash_value, page_datas):
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
||||
@@ -531,7 +548,9 @@ class HiCacheController:
|
||||
break
|
||||
if operation.increment(self.page_size):
|
||||
self.mem_pool_host.set_from_flat_data_page(
|
||||
operation.host_indices[operation.completed_tokens],
|
||||
operation.host_indices[
|
||||
operation.completed_tokens - self.page_size
|
||||
],
|
||||
page_data,
|
||||
)
|
||||
else:
|
||||
@@ -583,7 +602,7 @@ class HiCacheController:
|
||||
torch.distributed.all_reduce(
|
||||
storage_hit_count_tensor,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
group=self.prefetch_tp_group,
|
||||
)
|
||||
storage_hit_count = storage_hit_count_tensor.item()
|
||||
|
||||
@@ -635,21 +654,23 @@ class HiCacheController:
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_backup = operation.token_ids
|
||||
|
||||
last_hashes, data_pages = [], []
|
||||
for i in range(0, len(tokens_to_backup), self.page_size):
|
||||
last_hash = get_hash_str(
|
||||
tokens_to_backup[i : i + self.page_size], last_hash
|
||||
)
|
||||
success = self.storage_backend.set(
|
||||
last_hash,
|
||||
self.mem_pool_host.get_flat_data_page(
|
||||
operation.host_indices[i]
|
||||
),
|
||||
data_page = self.mem_pool_host.get_flat_data_page(
|
||||
operation.host_indices[i]
|
||||
)
|
||||
if not success:
|
||||
logger.warning(f"Failed to write page {last_hash} to storage.")
|
||||
break
|
||||
operation.completed_tokens += self.page_size
|
||||
operation.hash_value.append(last_hash)
|
||||
last_hashes.append(last_hash)
|
||||
data_pages.append(data_page)
|
||||
|
||||
success = self.storage_backend.batch_set(last_hashes, data_pages)
|
||||
if not success:
|
||||
logger.warning(f"Failed to write page {last_hashes} to storage.")
|
||||
else:
|
||||
operation.completed_tokens += len(tokens_to_backup)
|
||||
operation.hash_value.extend(last_hashes)
|
||||
|
||||
min_completed_tokens = operation.completed_tokens
|
||||
if self.tp_world_size > 1:
|
||||
@@ -659,7 +680,7 @@ class HiCacheController:
|
||||
torch.distributed.all_reduce(
|
||||
completed_tokens_tensor,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
group=self.backup_tp_group,
|
||||
)
|
||||
min_completed_tokens = completed_tokens_tensor.item()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user