Simple prefetch policy (#8692)
Adds a configurable stop policy for HiCache storage prefetching. A new server argument, --hicache-storage-prefetch-policy, selects between best_effort (stage the request with whatever has been fetched so far), wait_complete (hold the request until every page has landed in host memory), and timeout (hold until completion or a 3-second deadline). The change also rate-limits prefetching by capping tokens locked in ongoing prefetches at 80% of host-only cache capacity, adds cache-hit-rate reporting to the multi-turn benchmark, batches hf3fs page transfers, and unlinks the hf3fs client's shared-memory segments at construction time so they cannot leak.
@@ -20,6 +20,8 @@ from sglang.bench_serving import (
     sample_random_requests,
 )
 
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -139,7 +141,7 @@ async def async_request_sglang_generate(
     """
     Sends a streaming request to the server. Gathers text token-by-token.
     """
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         headers = {}
         generated_text = ""
         ttft = 0.0
@@ -150,6 +152,8 @@ async def async_request_sglang_generate(
         try:
             async with session.post(url=url, json=payload, headers=headers) as response:
                 if response.status == 200:
+                    prompt_tokens = 0
+                    cached_tokens = 0
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
@@ -168,6 +172,12 @@ async def async_request_sglang_generate(
                             if ttft == 0.0:
                                 ttft = time.perf_counter() - st
                                 output.ttft = ttft
+                                prompt_tokens = (data.get("meta_info") or {}).get(
+                                    "prompt_tokens", 0
+                                )
+                                cached_tokens = (data.get("meta_info") or {}).get(
+                                    "cached_tokens", 0
+                                )
 
                             # Decoding phase
                             else:
@@ -179,6 +189,8 @@ async def async_request_sglang_generate(
                 output.generated_text = generated_text
                 output.success = True
                 output.latency = latency
+                output.prompt_len = prompt_tokens
+                output.cached_tokens = cached_tokens
             else:
                 output.error = response.reason or ""
                 output.success = False
@@ -201,6 +213,7 @@ def gen_payload(prompt, output_len):
             "ignore_eos": True,
         },
         "stream": True,
+        "stream_options": {"include_usage": True},
         "lora_path": "",
         "return_logprob": False,
         "logprob_start_len": -1,
@@ -303,7 +316,12 @@ class WorkloadGenerator:
 
         self.response_queue = queue.Queue()
         self.pbar = tqdm(total=args.num_clients * args.num_rounds)
-        self.performance_metrics = {"ttft": [], "latency": []}
+        self.performance_metrics = {
+            "ttft": [],
+            "latency": [],
+            "prompt_len": [],
+            "cached_tokens": [],
+        }
 
     async def handle_request(self, item):
         try:
@@ -360,6 +378,8 @@ class WorkloadGenerator:
             self.client_records[client_id]["round"] += 1
             self.performance_metrics["ttft"].append(response.ttft)
             self.performance_metrics["latency"].append(response.latency)
+            self.performance_metrics["prompt_len"].append(response.prompt_len)
+            self.performance_metrics["cached_tokens"].append(response.cached_tokens)
             self.completed_requests += 1
 
             if self.client_records[client_id]["round"] < args.num_rounds:
@@ -416,6 +436,12 @@ class WorkloadGenerator:
                     len(self.performance_metrics["latency"]) // 2
                 ],
                 "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "cache_hit_rate": (
+                    0
+                    if sum(self.performance_metrics["prompt_len"]) == 0
+                    else sum(self.performance_metrics["cached_tokens"])
+                    / sum(self.performance_metrics["prompt_len"])
+                ),
             },
         }
         print("All requests completed")
@@ -434,6 +460,7 @@ class WorkloadGenerator:
         print(
             f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
+        print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
         log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
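
Note: the summary's hit rate is token-weighted; it sums the per-request counters before dividing rather than averaging per-request ratios. A minimal standalone sketch of the same computation (the list names are illustrative):

    def cache_hit_rate(prompt_lens, cached_tokens):
        # Token-weighted: sum(cached) / sum(prompt), so long prompts count
        # proportionally more than short ones, matching the summary above.
        total_prompt = sum(prompt_lens)
        if total_prompt == 0:
            return 0.0
        return sum(cached_tokens) / total_prompt

    assert cache_hit_rate([100, 300], [50, 300]) == 0.875
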
@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
 
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
         self._done_flag = False
         self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
@@ -278,6 +281,12 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
             # create a new communication group for synchronizing storage operations across TP workers
             self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
             if self.tp_world_size > 1:
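
Note: the capacity limit reserves headroom by capping prefetch-locked tokens at 80% of the host-only slack, i.e. the part of the host pool not mirrored by the device pool. Worked arithmetic with illustrative pool sizes:

    # Illustrative sizes only; real values come from the configured memory pools.
    mem_pool_host_size = 1_000_000    # tokens the host pool can hold
    mem_pool_device_size = 200_000    # tokens the device pool can hold
    prefetch_capacity_limit = int(0.8 * (mem_pool_host_size - mem_pool_device_size))
    assert prefetch_capacity_limit == 640_000
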
@@ -525,7 +534,7 @@ class HiCacheController:
         host_indices: torch.Tensor,
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
-    ) -> int:
+    ) -> PrefetchOperation:
         """
         Prefetch KV caches from storage backend to host memory.
         """
@@ -586,11 +595,23 @@ class HiCacheController:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
             except Empty:
                 continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit the prefetching operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much memory is occupied
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
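
Note: prefetch_rate_limit_check is the read half of a simple token budget. HiRadixCache does the writes: it adds len(new_input_tokens) to prefetch_tokens_occupied when launching a prefetch and subtracts the operation's token count when it is revoked or terminated (see the HiRadixCache hunks below). A hedged sketch that folds the check and the bookkeeping into one hypothetical class:

    class PrefetchBudget:
        """Single-writer token budget; mirrors prefetch_tokens_occupied."""

        def __init__(self, capacity: int):
            self.capacity = capacity  # e.g. the 80% host-slack limit above
            self.occupied = 0

        def admit(self, num_tokens: int) -> bool:
            # Mirrors prefetch_rate_limit_check: refuse once the budget is spent.
            if self.occupied >= self.capacity:
                return False
            self.occupied += num_tokens  # reserved at prefetch launch
            return True

        def release(self, num_tokens: int) -> None:
            self.occupied -= num_tokens  # returned on terminate or revoke
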
@@ -604,34 +625,36 @@ class HiCacheController:
                 if operation is None:
                     continue
 
-                last_hash = operation.last_hash
-                tokens_to_fetch = operation.token_ids
-
                 storage_hit_count = 0
-                remaining_tokens = len(tokens_to_fetch)
-                hash_value = []
-                while remaining_tokens >= self.page_size:
-                    last_hash = self.get_hash_str(
-                        tokens_to_fetch[
-                            storage_hit_count : storage_hit_count + self.page_size
-                        ],
-                        last_hash,
-                    )
-
-                    # todo, more unified interface
-                    if not self.is_mooncake_backend():
-                        if not self.storage_backend.exists(last_hash):
-                            break
-                    hash_value.append(last_hash)
-                    storage_hit_count += self.page_size
-                    remaining_tokens -= self.page_size
-
-                if self.is_mooncake_backend():
-                    # deferring to batch exists for mooncake store
-                    exist_result = self.storage_backend.exists(hash_value)
-                    storage_hit_count = (
-                        sum(1 for v in exist_result.values() if v != 0) * self.page_size
-                    )
+                if self.prefetch_rate_limit_check():
+                    last_hash = operation.last_hash
+                    tokens_to_fetch = operation.token_ids
+
+                    remaining_tokens = len(tokens_to_fetch)
+                    hash_value = []
+                    while remaining_tokens >= self.page_size:
+                        last_hash = self.get_hash_str(
+                            tokens_to_fetch[
+                                storage_hit_count : storage_hit_count + self.page_size
+                            ],
+                            last_hash,
+                        )
+
+                        # todo, more unified interface
+                        if not self.is_mooncake_backend():
+                            if not self.storage_backend.exists(last_hash):
+                                break
+                        hash_value.append(last_hash)
+                        storage_hit_count += self.page_size
+                        remaining_tokens -= self.page_size
+
+                    if self.is_mooncake_backend():
+                        # deferring to batch exists for mooncake store
+                        exist_result = self.storage_backend.exists(hash_value)
+                        storage_hit_count = (
+                            sum(1 for v in exist_result.values() if v != 0)
+                            * self.page_size
+                        )
 
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
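
Note: the while loop above walks a page-hash chain: each page's key is derived from the previous page's hash, so a hit on page k implies all earlier pages hit as well, and the first miss ends the matched prefix. An illustrative stand-in for get_hash_str (not sglang's exact hash construction):

    import hashlib
    from typing import List, Optional

    def get_hash_str(page_tokens: List[int], prior_hash: Optional[str] = None) -> str:
        # Chaining prior_hash into the digest makes each page key depend on the
        # whole prefix, which is what lets exists() probe pages in order.
        h = hashlib.sha256()
        if prior_hash is not None:
            h.update(prior_hash.encode())
        for t in page_tokens:
            h.update(t.to_bytes(4, "little", signed=False))
        return h.hexdigest()
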
@@ -750,6 +773,8 @@ class HiCacheController:
 
             if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
+            elif self.storage_backend_type == "hf3fs":
+                self.generic_page_backup(operation, batch_size=128)
             else:
                 self.generic_page_backup(operation)
@@ -619,6 +619,7 @@ class Scheduler(
             ),
             hicache_mem_layout=server_args.hicache_mem_layout,
             hicache_storage_backend=server_args.hicache_storage_backend,
+            hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
         )
         self.tp_worker.register_hicache_layer_transfer_counter(
             self.tree_cache.cache_controller.layer_done_counter
@@ -1572,7 +1573,10 @@ class Scheduler(
                 break
 
             if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging requests that are ongoing prefetch
+                    continue
 
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional
 
 import torch
 
-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):
 
         if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )
 
+        self.prefetch_stop_policy = hicache_storage_prefetch_policy
+        # todo: customizable storage prefetch timeout
+        self.prefetch_timeout = 3  # seconds
+        logger.info(
+            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+        )
+
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
             for _ in range(queue_size.item()):
                 req_id = self.cache_controller.prefetch_revoke_queue.get()
                 if req_id in self.ongoing_prefetch:
-                    last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
+                    last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                     last_host_node.release_host()
                     del self.ongoing_prefetch[req_id]
+                    self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
                 else:
                     # the revoked operation already got terminated
                     pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
             host_node.release_host()
             del self.ongoing_backup[ack_id]
 
-    def check_prefetch_progress(self, req_id: str):
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True
 
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
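
Note: with tensor parallelism, termination is a consensus: reducing the per-rank booleans with ReduceOp.MIN means a single rank that is still waiting holds everyone back, so all ranks stage the request in the same step. The same logic without torch.distributed, as a worked example:

    # rank -> local can_terminate vote (illustrative values)
    votes = {0: True, 1: True, 2: False, 3: True}
    # all_reduce with ReduceOp.MIN over {0, 1} is just min(): unanimity required.
    can_terminate = bool(min(int(v) for v in votes.values()))
    assert can_terminate is False  # rank 2 is unfinished, so every rank waits
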
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
             req_id
         ]
 
+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True
 
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
 
     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
@@ -96,6 +96,8 @@ class Hf3fsClient:
         )
         self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
         self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()
 
         self.rlock = threading.RLock()
         self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
         del self.iov_w
         self.shm_r.close()
         self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()
 
     def flush(self) -> None:
         os.fsync(self.file)
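
Note: moving unlink() into the constructor is the unlink-early idiom: once the shared memory is mapped (here, wrapped in iovecs), its name can be dropped immediately, so the OS reclaims the segment even if the process dies before close(); the mapping itself stays valid. A minimal sketch of the pattern (POSIX behavior):

    from multiprocessing import shared_memory

    shm = shared_memory.SharedMemory(create=True, size=4096)
    shm.buf[:5] = b"hello"   # mapping is live
    shm.unlink()             # drop the name now; nothing leaks if we crash later
    assert bytes(shm.buf[:5]) == b"hello"  # still readable until close()
    shm.close()
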
@@ -203,6 +203,7 @@ class ServerArgs:
    hicache_io_backend: str = "kernel"
    hicache_mem_layout: str = "layer_first"
    hicache_storage_backend: Optional[str] = None
+    hicache_storage_prefetch_policy: str = "best_effort"
 
    # Double Sparsity
    enable_double_sparsity: bool = False
@@ -1626,6 +1627,13 @@ class ServerArgs:
             default=ServerArgs.hicache_storage_backend,
             help="The storage backend for hierarchical KV cache.",
         )
+        parser.add_argument(
+            "--hicache-storage-prefetch-policy",
+            type=str,
+            choices=["best_effort", "wait_complete", "timeout"],
+            default=ServerArgs.hicache_storage_prefetch_policy,
+            help="Control when prefetching from the storage backend should stop.",
+        )
 
         # Double Sparsity
         parser.add_argument(
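
Usage, assuming the standard sglang launch entry point with hierarchical caching and a storage backend enabled (the model path and backend choice here are illustrative):

    python -m sglang.launch_server --model-path <model> \
        --enable-hierarchical-cache \
        --hicache-storage-backend file \
        --hicache-storage-prefetch-policy timeout
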