HiCache, add bench long context plus minor fixs (#9086)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
96
benchmark/hicache/bench_long_context.py
Normal file
96
benchmark/hicache/bench_long_context.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from bench_multiturn import (
|
||||||
|
ReadyQueue,
|
||||||
|
WorkloadGenerator,
|
||||||
|
gen_payload,
|
||||||
|
log_to_jsonl_file,
|
||||||
|
parse_args,
|
||||||
|
)
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
|
||||||
|
from sglang.bench_serving import get_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class ContextWorkloadGenerator(WorkloadGenerator):
|
||||||
|
def __init__(self, args):
|
||||||
|
# Construct the base URL for requests
|
||||||
|
self.baseurl = f"http://{args.host}:{args.port}/"
|
||||||
|
self.url = self.baseurl + "generate"
|
||||||
|
|
||||||
|
self.tokenizer = get_tokenizer(args.model_path)
|
||||||
|
self.distribution = args.distribution
|
||||||
|
self.request_rate = args.request_rate
|
||||||
|
self.start_time = None
|
||||||
|
self.finished_time = None
|
||||||
|
|
||||||
|
self.sent_requests = 0
|
||||||
|
self.completed_requests = 0
|
||||||
|
|
||||||
|
self.dataset = json.load(open(args.dataset_path))
|
||||||
|
|
||||||
|
init_requests = []
|
||||||
|
for i in range(min(args.num_clients, len(self.dataset["queries"]))):
|
||||||
|
context_id = self.dataset["queries"][i]["context"]
|
||||||
|
init_requests.append(
|
||||||
|
(
|
||||||
|
i,
|
||||||
|
gen_payload(
|
||||||
|
self.dataset["contexts"][context_id]
|
||||||
|
+ self.dataset["queries"][i]["question"],
|
||||||
|
len(
|
||||||
|
self.tokenizer(
|
||||||
|
self.dataset["queries"][i]["reference_answer"]
|
||||||
|
)["input_ids"]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.ready_queue = ReadyQueue(init_requests=init_requests)
|
||||||
|
|
||||||
|
self.response_queue = queue.Queue()
|
||||||
|
self.pbar = tqdm(total=args.num_clients * args.num_rounds)
|
||||||
|
self.performance_metrics = {
|
||||||
|
"ttft": [],
|
||||||
|
"latency": [],
|
||||||
|
"itl": [],
|
||||||
|
"prompt_len": [],
|
||||||
|
"cached_tokens": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.max_parallel = args.max_parallel
|
||||||
|
self.logfile = args.log_file
|
||||||
|
|
||||||
|
def response_handler(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
client_id, response = self.response_queue.get(
|
||||||
|
timeout=10
|
||||||
|
) # Block until response is available
|
||||||
|
if not response.success:
|
||||||
|
raise ValueError(f"Request failed with error: {response.error}")
|
||||||
|
self.performance_metrics["ttft"].append(response.ttft)
|
||||||
|
self.performance_metrics["itl"].extend(response.itl)
|
||||||
|
self.performance_metrics["latency"].append(response.latency)
|
||||||
|
self.completed_requests += 1
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
if self.pbar.n == self.pbar.total:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
args.num_rounds = 1
|
||||||
|
args.max_parallel = 128
|
||||||
|
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
||||||
|
|
||||||
|
for request_rate in [24, 16, 12, 8, 4, 2, 1]:
|
||||||
|
args.request_rate = request_rate
|
||||||
|
requests.post(flush_cache_url)
|
||||||
|
time.sleep(1)
|
||||||
|
performance_data = ContextWorkloadGenerator(args).run()
|
||||||
|
log_to_jsonl_file(performance_data, args.log_file, args.tag)
|
||||||
@@ -322,6 +322,9 @@ class WorkloadGenerator:
|
|||||||
"prompt_len": [],
|
"prompt_len": [],
|
||||||
"cached_tokens": [],
|
"cached_tokens": [],
|
||||||
}
|
}
|
||||||
|
self.num_rounds = args.num_rounds
|
||||||
|
self.max_parallel = args.max_parallel
|
||||||
|
self.output_length = args.output_length
|
||||||
|
|
||||||
async def handle_request(self, item):
|
async def handle_request(self, item):
|
||||||
try:
|
try:
|
||||||
@@ -336,7 +339,7 @@ class WorkloadGenerator:
|
|||||||
def request_sender(self):
|
def request_sender(self):
|
||||||
async def request_loop():
|
async def request_loop():
|
||||||
while True:
|
while True:
|
||||||
if self.sent_requests - self.completed_requests < args.max_parallel:
|
if self.sent_requests - self.completed_requests < self.max_parallel:
|
||||||
new_request = self.ready_queue.pop()
|
new_request = self.ready_queue.pop()
|
||||||
if new_request:
|
if new_request:
|
||||||
asyncio.create_task(self.handle_request(new_request))
|
asyncio.create_task(self.handle_request(new_request))
|
||||||
@@ -382,7 +385,7 @@ class WorkloadGenerator:
|
|||||||
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
|
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
|
||||||
self.completed_requests += 1
|
self.completed_requests += 1
|
||||||
|
|
||||||
if self.client_records[client_id]["round"] < args.num_rounds:
|
if self.client_records[client_id]["round"] < self.num_rounds:
|
||||||
# append new request to client's history
|
# append new request to client's history
|
||||||
self.client_records[client_id][
|
self.client_records[client_id][
|
||||||
"history"
|
"history"
|
||||||
@@ -392,7 +395,7 @@ class WorkloadGenerator:
|
|||||||
client_id,
|
client_id,
|
||||||
gen_payload(
|
gen_payload(
|
||||||
self.client_records[client_id]["history"],
|
self.client_records[client_id]["history"],
|
||||||
args.output_length,
|
self.output_length,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -461,7 +464,7 @@ class WorkloadGenerator:
|
|||||||
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
||||||
)
|
)
|
||||||
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
|
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
|
||||||
log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
|
return performance_data
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -482,4 +485,5 @@ if __name__ == "__main__":
|
|||||||
args.request_rate = rate
|
args.request_rate = rate
|
||||||
requests.post(flush_cache_url)
|
requests.post(flush_cache_url)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
WorkloadGenerator(args).run()
|
performance_data = WorkloadGenerator(args).run()
|
||||||
|
log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
|
||||||
|
|||||||
@@ -44,9 +44,9 @@ Look for log entries like this:
|
|||||||
[2025-08-11 17:17:03] max_total_num_tokens=665690, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=65536, available_gpu_mem=13.50 GB
|
[2025-08-11 17:17:03] max_total_num_tokens=665690, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=65536, available_gpu_mem=13.50 GB
|
||||||
```
|
```
|
||||||
|
|
||||||
Check the `available_gpu_mem` value.
|
Check the `available_gpu_mem` value.
|
||||||
- If it is between 5–8 GB, the setting is good.
|
- If it is between 5–8 GB, the setting is good.
|
||||||
- If it is too high (e.g., 10 - 20 GB), increase `--mem-fraction-static` to allocate more memory to the KV cache.
|
- If it is too high (e.g., 10 - 20 GB), increase `--mem-fraction-static` to allocate more memory to the KV cache.
|
||||||
- If it is too low, you risk out-of-memory (OOM) errors later, so decrease `--mem-fraction-static`.
|
- If it is too low, you risk out-of-memory (OOM) errors later, so decrease `--mem-fraction-static`.
|
||||||
|
|
||||||
Another straightforward approach is to increase `--mem-fraction-static` in increments of 0.01 until you encounter OOM errors for your workloads.
|
Another straightforward approach is to increase `--mem-fraction-static` in increments of 0.01 until you encounter OOM errors for your workloads.
|
||||||
|
|||||||
@@ -71,8 +71,10 @@ class HiRadixCache(RadixCache):
|
|||||||
self.tp_group = tp_cache_group
|
self.tp_group = tp_cache_group
|
||||||
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
||||||
self.enable_storage = hicache_storage_backend is not None
|
self.enable_storage = hicache_storage_backend is not None
|
||||||
# todo: customizable storage prefetch threshold
|
# todo: customizable storage prefetch threshold and timeout
|
||||||
self.prefetch_threshold = 256
|
self.prefetch_threshold = 256
|
||||||
|
self.prefetch_timeout = 3 # seconds
|
||||||
|
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
||||||
|
|
||||||
self.load_cache_event = threading.Event()
|
self.load_cache_event = threading.Event()
|
||||||
self.cache_controller = HiCacheController(
|
self.cache_controller = HiCacheController(
|
||||||
@@ -87,13 +89,6 @@ class HiRadixCache(RadixCache):
|
|||||||
prefetch_threshold=self.prefetch_threshold,
|
prefetch_threshold=self.prefetch_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
|
||||||
# todo: customizable storage prefetch timeout
|
|
||||||
self.prefetch_timeout = 3 # seconds
|
|
||||||
logger.info(
|
|
||||||
f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# record the nodes with ongoing write through
|
# record the nodes with ongoing write through
|
||||||
self.ongoing_write_through = {}
|
self.ongoing_write_through = {}
|
||||||
# record the node segments with ongoing load back
|
# record the node segments with ongoing load back
|
||||||
|
|||||||
Reference in New Issue
Block a user