diff --git a/benchmark/hicache/README.md b/benchmark/hicache/README.md
new file mode 100644
index 000000000..b452a48a3
--- /dev/null
+++ b/benchmark/hicache/README.md
@@ -0,0 +1,25 @@
+## Run synthetic multi-turn benchmark
+
+```
+# SGLang server with radix cache disabled
+python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache
+
+# SGLang server with radix cache on and first-come-first-served policy
+python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs
+
+# The default SGLang server with radix cache on and longest-prefix-match policy
+python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000
+
+# SGLang server with hierarchical radix cache enabled
+python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache
+
+```
+
+```
+python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
+```
+
+Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity: the more tokens are reused, the larger the model, and the tighter the GPU memory budget, the greater the benefit to expect from hierarchical caching.
+
+
+## More benchmarks to be added
diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py
index ab34c33da..1fb58e024 100644
--- a/benchmark/hicache/bench_multiturn.py
+++ b/benchmark/hicache/bench_multiturn.py
@@ -5,6 +5,7 @@ import queue
 import random
 import threading
 import time
+from datetime import datetime
 from typing import Optional
 
 import aiohttp
@@ -26,9 +27,15 @@ def parse_args():
     parser.add_argument(
         "--num-clients",
         type=int,
-        default=200,
+        default=256,
         help="Number of concurrent clients",
     )
+    parser.add_argument(
+        "--max-parallel",
+        type=int,
+        default=128,
+        help="Maximum number of parallel requests",
+    )
     parser.add_argument(
         "--request-length",
         type=int,
@@ -73,11 +80,17 @@ def parse_args():
         help="Server port (default: 30000)",
     )
     parser.add_argument(
-        "--model",
+        "--model-path",
         type=str,
         default="meta-llama/Llama-3.1-8B-Instruct",
         help="model path compatible with Hugging Face Transformers",
     )
+    parser.add_argument(
+        "--log-file",
+        type=str,
+        default="performance_metrics.jsonl",
+        help="File to log performance metrics",
+    )
     return parser.parse_args()
@@ -158,6 +171,18 @@ def gen_payload(prompt, output_len):
     return payload
 
 
+def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
+    """Append the data with a timestamp to the specified JSONL file."""
+    timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
+    try:
+        with open(file_path, "a") as file:
+            file.write(
+                json.dumps(timestamped_data) + "\n"
+            )  # Write as a single line in JSONL format
+    except IOError as e:
+        print(f"Error writing to JSONL file: {e}")
+
+
 class ReadyQueue:
     """
     Thread-safe queue that can pop requests in different orders based on given policy.
@@ -191,12 +216,15 @@ class WorkloadGenerator: # Construct the base URL for requests self.url = f"http://{args.host}:{args.port}/generate" - self.tokenizer = get_tokenizer(args.model) + self.tokenizer = get_tokenizer(args.model_path) self.distribution = args.distribution self.request_rate = args.request_rate self.start_time = None self.finished_time = None + self.sent_requests = 0 + self.completed_requests = 0 + self.candidate_inputs = sample_random_requests( input_len=args.request_length, output_len=args.output_length, @@ -235,6 +263,18 @@ class WorkloadGenerator: def request_sender(self): async def request_loop(): while True: + if self.sent_requests - self.completed_requests < args.max_parallel: + new_request = self.ready_queue.pop() + if new_request: + asyncio.create_task(self.handle_request(new_request)) + self.sent_requests += 1 + else: + await asyncio.sleep(0.05) + continue + + if self.pbar.n == self.pbar.total: + break + # Calculate Poisson-distributed wait time if self.distribution == "poisson": sleep_time = random.expovariate(self.request_rate) @@ -247,14 +287,6 @@ class WorkloadGenerator: raise ValueError("Invalid distribution type") await asyncio.sleep(sleep_time) # Wait before sending the next request - new_request = self.ready_queue.pop() - # Submit async request - if new_request: - asyncio.create_task(self.handle_request(new_request)) - else: - if self.pbar.n == self.pbar.total: - break - # Create and run the event loop for asynchronous requests loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -273,6 +305,7 @@ class WorkloadGenerator: self.client_records[client_id]["round"] += 1 self.performance_metrics["ttft"].append(response.ttft) self.performance_metrics["latency"].append(response.latency) + self.completed_requests += 1 if self.client_records[client_id]["round"] < args.num_rounds: self.client_records[client_id][ @@ -301,34 +334,56 @@ class WorkloadGenerator: request_thread.join() response_thread.join() - self.pbar.close() - print("All requests completed.") + + performance_data = { + "summary": { + "total_requests": len(self.performance_metrics["ttft"]), + "request_rate": self.request_rate, + "average_ttft": sum(self.performance_metrics["ttft"]) + / len(self.performance_metrics["ttft"]), + "p90_ttft": sorted(self.performance_metrics["ttft"])[ + int(0.9 * len(self.performance_metrics["ttft"])) + ], + "median_ttft": sorted(self.performance_metrics["ttft"])[ + len(self.performance_metrics["ttft"]) // 2 + ], + "average_latency": sum(self.performance_metrics["latency"]) + / len(self.performance_metrics["latency"]), + "p90_latency": sorted(self.performance_metrics["latency"])[ + int(0.9 * len(self.performance_metrics["latency"])) + ], + "median_latency": sorted(self.performance_metrics["latency"])[ + len(self.performance_metrics["latency"]) // 2 + ], + "throughput": self.pbar.total / (self.finished_time - self.start_time), + }, + } + print("All requests completed") print("Performance metrics summary:") print( - f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second" + f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second" ) + print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}") + print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}") + print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}") print( - f" Average TTFT: {sum(self.performance_metrics['ttft']) / 
len(self.performance_metrics['ttft']):.2f}" + f" Average latency: {performance_data['summary']['average_latency']:.2f}" ) + print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}") + print(f" Median latency: {performance_data['summary']['median_latency']:.2f}") print( - f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}" + f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second" ) - print( - f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}" - ) - print( - f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}" - ) - throughput = self.pbar.total / (self.finished_time - self.start_time) - print(f"Throughput: {throughput:.2f} requests per second") + log_to_jsonl_file(performance_data, args.log_file) if __name__ == "__main__": args = parse_args() flush_cache_url = f"http://{args.host}:{args.port}/flush_cache" - for request_rate in range(1, 41, 2): + for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]: args.request_rate = request_rate requests.post(flush_cache_url) + time.sleep(1) WorkloadGenerator(args).run() diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 4560a2708..1a8ad6a6e 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -5,9 +5,7 @@ Copyright 2023-2025 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. """ +import concurrent.futures import logging +import math import threading -from queue import PriorityQueue, Queue -from typing import Optional +from queue import Empty, Full, PriorityQueue, Queue +from typing import List, Optional import torch @@ -55,6 +55,27 @@ class CacheOperation: self.priority = min(self.priority, other.priority) self.node_ids.extend(other.node_ids) + def split(self, factor) -> List["CacheOperation"]: + # split an operation into smaller operations to reduce the size of intermediate buffers + if factor <= 1: + return [self] + + chunk_size = math.ceil(len(self.host_indices) / factor) + split_ops = [] + for i in range(0, len(self.host_indices), chunk_size): + split_ops.append( + CacheOperation( + host_indices=self.host_indices[i : i + chunk_size], + device_indices=self.device_indices[i : i + chunk_size], + node_id=0, + ) + ) + # Inherit the node_ids on the final chunk + if split_ops: + split_ops[-1].node_ids = self.node_ids + + return split_ops + def __lt__(self, other: "CacheOperation"): return self.priority < other.priority @@ -64,7 +85,10 @@ class TransferBuffer: Overlapping buffer preparation and transfer operations to improve throughput. 
""" - def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None: + def __init__( + self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000 + ) -> None: + self.stop_event = stop_event self.buffers = Queue(maxsize=buffer_count) # todo: adjust the buffer size based on throughput profile of the system self.max_buffer_size = max_buffer_size @@ -75,15 +99,29 @@ class TransferBuffer: def empty(self) -> bool: return self.buffers.empty() - def put(self, item, block=True) -> None: - self.buffers.put(item, block=block) + def put(self, item, block=True, timeout=1) -> None: + while not self.stop_event.is_set(): + try: + self.buffers.put(item, block=block, timeout=timeout) + break + except Full: + if not block: + break + continue + except Exception as e: + logger.error(e) - def get(self, block=True) -> Optional[CacheOperation]: + def get(self, block=True, timeout=1) -> Optional[CacheOperation]: try: - return self.buffers.get(block=block) + return self.buffers.get(block=block, timeout=timeout) + except Empty: + return None except Exception as e: logger.error(e) + def clear(self): + self.buffers.queue.clear() + class HiCacheController: @@ -111,8 +149,11 @@ class HiCacheController: self.ack_write_queue = Queue() self.ack_load_queue = Queue() - self.write_buffer = TransferBuffer() - self.load_buffer = TransferBuffer() + self.stop_event = threading.Event() + self.write_buffer = TransferBuffer(self.stop_event) + self.load_buffer = TransferBuffer( + self.stop_event, buffer_count=10, max_buffer_size=100 + ) self.write_stream = torch.cuda.Stream() self.load_stream = torch.cuda.Stream() @@ -126,6 +167,28 @@ class HiCacheController: self.write_thread.start() self.load_thread.start() + def reset(self): + self.stop_event.set() + self.write_thread.join() + self.load_thread.join() + + self.write_queue.queue.clear() + self.load_queue.queue.clear() + self.write_buffer.clear() + self.load_buffer.clear() + self.ack_write_queue.queue.clear() + self.ack_load_queue.queue.clear() + + self.write_thread = threading.Thread( + target=self.write_thread_func_buffer, daemon=True + ) + self.load_thread = threading.Thread( + target=self.load_thread_func_buffer, daemon=True + ) + self.stop_event.clear() + self.write_thread.start() + self.load_thread.start() + def write( self, device_indices: torch.Tensor, @@ -138,10 +201,10 @@ class HiCacheController: host_indices = self.mem_pool_host.alloc(len(device_indices)) if host_indices is None: return None + self.mem_pool_host.protect_write(host_indices) self.write_queue.put( CacheOperation(host_indices, device_indices, node_id, priority) ) - self.mem_pool_host.protect_write(host_indices) return host_indices def load( @@ -156,10 +219,10 @@ class HiCacheController: device_indices = self.mem_pool_device.alloc(len(host_indices)) if device_indices is None: return None + self.mem_pool_host.protect_load(host_indices) self.load_queue.put( CacheOperation(host_indices, device_indices, node_id, priority) ) - self.mem_pool_host.protect_load(host_indices) return device_indices def write_thread_func_direct(self): @@ -167,16 +230,19 @@ class HiCacheController: Directly write through KV caches to host memory without buffering. 
""" with torch.cuda.stream(self.write_stream): - while True: + while not self.stop_event.is_set(): try: - operation = self.write_queue.get(block=True) + operation = self.write_queue.get(block=True, timeout=1) operation.data = self.mem_pool_device.get_flat_data( operation.device_indices ) self.mem_pool_host.transfer(operation.host_indices, operation.data) self.mem_pool_host.complete_io(operation.host_indices) for node_id in operation.node_ids: - self.ack_write_queue.put(node_id) + if node_id != 0: + self.ack_write_queue.put(node_id) + except Empty: + continue except Exception as e: logger.error(e) @@ -185,9 +251,10 @@ class HiCacheController: Directly load KV caches from host memory to device memory without buffering. """ with torch.cuda.stream(self.load_stream): - while True: + while not self.stop_event.is_set(): try: - operation = self.load_queue.get(block=True) + operation = self.load_queue.get(block=True, timeout=1) + # time.sleep(18e-6 * len(operation.host_indices)) operation.data = self.mem_pool_host.get_flat_data( operation.host_indices ) @@ -196,7 +263,10 @@ class HiCacheController: ) self.mem_pool_host.complete_io(operation.host_indices) for node_id in operation.node_ids: - self.ack_load_queue.put(node_id) + if node_id != 0: + self.ack_load_queue.put(node_id) + except Empty: + continue except Exception as e: logger.error(e) @@ -204,39 +274,98 @@ class HiCacheController: """ Auxiliary function to prepare the buffer for write operations. """ + + def _to_op(op_): + assert op_.device_indices.is_cuda, "Device indices should be on GPU" + op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to( + self.mem_pool_host.device + ) + self.write_buffer.put(op_) + return op_ + buffer = None - while True: - try: - operation = self.write_queue.get(block=True) - if buffer is None: - buffer = operation - else: - buffer.merge(operation) - if ( - no_wait - or len(buffer.host_indices) >= self.write_buffer.max_buffer_size - or self.write_queue.empty() - or self.write_buffer.empty() - ): - assert ( - buffer.device_indices.is_cuda - ), "Device indices should be on GPU" - buffer.data = self.mem_pool_device.get_flat_data( - buffer.device_indices - ).contiguous() - self.write_buffer.put(buffer, block=True) - buffer = None - except Exception as e: - logger.error(e) + with torch.cuda.stream(self.write_stream): + while not self.stop_event.is_set(): + try: + operation = self.write_queue.get(block=True, timeout=1) + factor = ( + len(operation.device_indices) + // self.write_buffer.max_buffer_size + ) + + if factor >= 1: + if buffer is not None: + _to_op(buffer) + buffer = None + + if factor < 2: + _to_op(operation) + else: + split_ops = operation.split(factor) + for op_ in split_ops: + _to_op(op_) + continue + + if buffer is None: + buffer = operation + else: + buffer.merge(operation) + if ( + no_wait + or len(buffer.host_indices) >= self.write_buffer.max_buffer_size + or self.write_queue.empty() + or self.write_buffer.empty() + ): + _to_op(buffer) + buffer = None + except Empty: + continue + except Exception as e: + logger.error(e) def load_aux_func(self): """ Auxiliary function to prepare the buffer for load operations. 
""" + + def _pin_op(op_, put=True): + op_.data = ( + self.mem_pool_host.get_flat_data(op_.host_indices) + .contiguous() + .pin_memory() + ) + if put: + self.load_buffer.put(op_) + return op_ + buffer = None - while True: + while not self.stop_event.is_set(): try: - operation = self.load_queue.get(block=True) + operation = self.load_queue.get(block=True, timeout=1) + factor = len(operation.host_indices) // self.load_buffer.max_buffer_size + + if factor >= 1: + if buffer is not None: + _pin_op(buffer) + buffer = None + + if factor < 2: + _pin_op(operation) + else: + split_ops = operation.split(factor) + split_args = [(op_, True) for op_ in split_ops[:-1]] + split_args.append((split_ops[-1], False)) + # Spawn threads to pin each op concurrently + with concurrent.futures.ThreadPoolExecutor() as executor: + pinned_ops = list( + executor.map( + lambda x: _pin_op(x[0], put=x[1]), split_args + ) + ) + # preserve the order of last op to ensure correct ack + self.load_buffer.put(pinned_ops[-1]) + continue + if buffer is None: buffer = operation else: @@ -246,41 +375,43 @@ class HiCacheController: or self.load_queue.empty() or self.load_buffer.empty() ): - buffer.data = ( - self.mem_pool_host.get_flat_data(buffer.host_indices) - .contiguous() - .pin_memory() - ) - self.load_buffer.put(buffer, block=True) + _pin_op(buffer) buffer = None + except Empty: + continue except Exception as e: logger.error(e) def write_thread_func_buffer(self): aux_thread = threading.Thread(target=self.write_aux_func, daemon=True) aux_thread.start() - with torch.cuda.stream(self.write_stream): - while True: - operation = self.write_buffer.get() - if operation is None: - continue - self.mem_pool_host.transfer(operation.host_indices, operation.data) - self.mem_pool_host.complete_io(operation.host_indices) - for node_id in operation.node_ids: + + while not self.stop_event.is_set(): + operation = self.write_buffer.get() + if operation is None: + continue + self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + if node_id != 0: self.ack_write_queue.put(node_id) + aux_thread.join() def load_thread_func_buffer(self): aux_thread = threading.Thread(target=self.load_aux_func, daemon=True) aux_thread.start() + with torch.cuda.stream(self.load_stream): - while True: + while not self.stop_event.is_set(): operation = self.load_buffer.get() if operation is None: continue self.mem_pool_device.transfer(operation.device_indices, operation.data) self.mem_pool_host.complete_io(operation.host_indices) for node_id in operation.node_ids: - self.ack_load_queue.put(node_id) + if node_id != 0: + self.ack_load_queue.put(node_id) + aux_thread.join() def evict_device( self, device_indices: torch.Tensor, host_indices: torch.Tensor diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a2f8b0479..b9f02bc61 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -82,6 +82,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache +from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from 
sglang.srt.model_executor.forward_batch_info import ForwardMode @@ -300,16 +301,24 @@ class Scheduler: token_to_kv_pool=self.token_to_kv_pool, ) else: - self.tree_cache = RadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool=self.token_to_kv_pool, - disable=server_args.disable_radix_cache, + self.tree_cache = ( + HiRadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool, + ) + if self.enable_hierarchical_cache + else RadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool, + disable=server_args.disable_radix_cache, + ) ) self.tree_cache_metrics = {"total": 0, "hit": 0} self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) # Init running status self.waiting_queue: List[Req] = [] + self.staging_reqs = {} # The running decoding batch for continuous batching self.running_batch: Optional[ScheduleBatch] = None # The current forward batch @@ -953,6 +962,30 @@ class Scheduler: break req.init_next_round_input(None if prefix_computed else self.tree_cache) + + if self.enable_hierarchical_cache and req.last_node is not None: + if req.last_node.evicted: + # loading KV cache for the request + req.last_node, req.prefix_indices = self.tree_cache.init_load_back( + req.last_node, + req.prefix_indices, + adder.rem_total_tokens, + ) + if req.last_node.loading: + # to prevent frequent cache invalidation + if req.rid in self.staging_reqs: + self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid]) + self.tree_cache.inc_lock_ref(req.last_node) + self.staging_reqs[req.rid] = req.last_node + continue + elif req.last_node.loading: + if not self.tree_cache.loading_complete(req.last_node): + continue + + if req.rid in self.staging_reqs: + self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid]) + del self.staging_reqs[req.rid] + res = adder.add_one_req(req) if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py new file mode 100644 index 000000000..4a57eacd1 --- /dev/null +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -0,0 +1,394 @@ +import heapq +import logging +import time +from typing import List, Optional + +import torch + +from sglang.srt.managers.cache_controller import HiCacheController +from sglang.srt.mem_cache.memory_pool import ( + BaseTokenToKVPool, + MLATokenToKVPoolHost, + ReqToTokenPool, +) +from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match + +logger = logging.getLogger(__name__) + + +class HiRadixCache(RadixCache): + + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool: BaseTokenToKVPool, + ): + self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool) + self.cache_controller = HiCacheController( + token_to_kv_pool, self.token_to_kv_pool_host + ) + + # record the nodes with ongoing write through + self.ongoing_write_through = {} + # record the node segments with ongoing load back + self.ongoing_load_back = {} + # todo: dynamically adjust the threshold + self.write_through_threshold = 1 + self.load_back_threshold = 10 + super().__init__(req_to_token_pool, token_to_kv_pool, disable=False) + + def reset(self): + TreeNode.counter = 0 + self.cache_controller.reset() + self.token_to_kv_pool_host.clear() + super().reset() + + def get_height(self, node: TreeNode): + height = 0 + while node != self.root_node: + node = node.parent + height += 1 + return height + + def write_backup(self, node: TreeNode): + 
+        host_indices = self.cache_controller.write(
+            device_indices=node.value,
+            priority=-self.get_height(node),
+            node_id=node.id,
+        )
+        if host_indices is None:
+            self.evict_host(len(node.value))
+            host_indices = self.cache_controller.write(
+                device_indices=node.value,
+                priority=-self.get_height(node),
+                node_id=node.id,
+            )
+        if host_indices is not None:
+            node.host_value = host_indices
+            self.ongoing_write_through[node.id] = node
+            self.inc_lock_ref(node)
+        else:
+            return None
+
+        return len(host_indices)
+
+    def inc_hit_count(self, node: TreeNode):
+        if self.cache_controller.write_policy != "write_through_selective":
+            return
+        node.hit_count += 1
+        if node.host_value is None and node.hit_count > self.write_through_threshold:
+            self.write_backup(node)
+            node.hit_count = 0
+
+    def writing_check(self):
+        while not self.cache_controller.ack_write_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_write_queue.get_nowait()
+                self.dec_lock_ref(self.ongoing_write_through[ack_id])
+                # clear the reference
+                del self.ongoing_write_through[ack_id]
+            except Exception:
+                break
+
+    def loading_check(self):
+        while not self.cache_controller.ack_load_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_load_queue.get_nowait()
+                start_node, end_node = self.ongoing_load_back[ack_id]
+                self.dec_lock_ref(end_node)
+                while end_node != start_node:
+                    assert end_node.loading
+                    end_node.loading = False
+                    end_node = end_node.parent
+                # clear the reference
+                del self.ongoing_load_back[ack_id]
+            except Exception:
+                break
+
+    def evictable_size(self):
+        self.writing_check()
+        self.loading_check()
+        return self.evictable_size_
+
+    def evict(self, num_tokens: int, evict_callback=None):
+        leaves = self._collect_leaves_device()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        pending_nodes = []
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+
+            if x.lock_ref > 0:
+                continue
+
+            if x.host_value is None:
+                if self.cache_controller.write_policy == "write_back":
+                    num_evicted += self.write_backup(x)
+                    # defer freeing the device copy until the write back completes
+                    pending_nodes.append(x)
+                elif self.cache_controller.write_policy == "write_through_selective":
+                    num_evicted += self._evict_write_through_selective(x)
+                else:
+                    assert (
+                        self.cache_controller.write_policy != "write_through"
+                    ), "write_through should be inclusive"
+                    raise NotImplementedError
+            else:
+                num_evicted += self._evict_write_through(x)
+
+            for child in x.parent.children.values():
+                if child in pending_nodes:
+                    continue
+                if not child.evicted:
+                    break
+            else:
+                # all children are evicted or no children
+                heapq.heappush(leaves, x.parent)
+
+        if self.cache_controller.write_policy == "write_back":
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                self.writing_check()
+                time.sleep(0.1)
+            # now that host copies exist, the device copies can be released
+            for node in pending_nodes:
+                self._evict_write_through(node)
+
+    def _evict_write_through(self, node: TreeNode):
+        # evict a node already written to host
+        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        assert num_evicted > 0
+        self.evictable_size_ -= num_evicted
+        node.value = None
+        return num_evicted
+
+    def _evict_write_through_selective(self, node: TreeNode):
+        # evict a node not initiated write to host
+        self.cache_controller.mem_pool_device.free(node.value)
+        num_evicted = len(node.value)
+        self._delete_leaf(node)
+        return num_evicted
+
+    def evict_host(self, num_tokens: int):
+        leaves = self._collect_leaves()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+            if x == self.root_node:
+                break
+            # only evict the host value of evicted nodes
+            if not x.evicted:
+                continue
+            assert x.lock_ref == 0 and x.host_value is not None
+
+            num_evicted += self.cache_controller.evict_host(x.host_value)
+            for k, v in x.parent.children.items():
+                if v == x:
+                    break
+            del x.parent.children[k]
+
+            if len(x.parent.children) == 0 and x.parent.evicted:
+                heapq.heappush(leaves, x.parent)
+
+    def load_back(
+        self, node: TreeNode, mem_quota: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        # todo: more loading policies
+
+        last_hit_node = node
+        nodes_to_load = []
+        while node.evicted:
+            assert (
+                node.backuped
+            ), "No backup available on evicted nodes, should not happen"
+            nodes_to_load.insert(0, node)
+            node = node.parent
+        ancestor_node = node
+
+        # protect the ancestor nodes from eviction
+        delta = self.inc_lock_ref(ancestor_node)
+
+        # load it all or not at all
+        host_indices = torch.cat([n.host_value for n in nodes_to_load])
+        if len(host_indices) < self.load_back_threshold or (
+            len(host_indices) > mem_quota + delta if mem_quota is not None else False
+        ):
+            # skip loading back if the total size is too small or exceeding the memory quota
+            self.dec_lock_ref(ancestor_node)
+            return None
+
+        device_indices = self.cache_controller.load(
+            host_indices=host_indices, node_id=last_hit_node.id
+        )
+        if device_indices is None:
+            self.evict(len(host_indices))
+            device_indices = self.cache_controller.load(
+                host_indices=host_indices, node_id=last_hit_node.id
+            )
+        self.dec_lock_ref(ancestor_node)
+        if device_indices is None:
+            # no sufficient GPU memory to load back KV caches
+            return None
+
+        self.ongoing_load_back[last_hit_node.id] = (ancestor_node, last_hit_node)
+        offset = 0
+        for node in nodes_to_load:
+            node.value = device_indices[offset : offset + len(node.host_value)]
+            offset += len(node.host_value)
+            node.loading = True
+        self.evictable_size_ += len(device_indices)
+        self.inc_lock_ref(last_hit_node)
+
+        return device_indices
+
+    def loading_complete(self, node: TreeNode):
+        self.loading_check()
+        return not node.loading
+
+    def init_load_back(
+        self,
+        last_node: TreeNode,
+        prefix_indices: torch.Tensor,
+        mem_quota: Optional[int] = None,
+    ):
+        assert (
+            len(prefix_indices) == 0 or prefix_indices.is_cuda
+        ), "indices of device KV caches should be on GPU"
+        if last_node.evicted:
+            loading_values = self.load_back(last_node, mem_quota)
+            if loading_values is not None:
+                prefix_indices = (
+                    loading_values
+                    if len(prefix_indices) == 0
+                    else torch.cat([prefix_indices, loading_values])
+                )
+                logger.debug(
+                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
+                )
+
+            while last_node.evicted:
+                last_node = last_node.parent
+
+        return last_node, prefix_indices
+
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
+                if not new_node.evicted:
+                    value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                self.inc_hit_count(child)
+                if not child.evicted:
+                    value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len: int):
+        # child node split into new_node -> child
+        new_node = TreeNode()
+        new_node.children = {key[split_len]: child}
new_node.parent = child.parent + new_node.lock_ref = child.lock_ref + new_node.key = child.key[:split_len] + new_node.loading = child.loading + + # split value and host value if exists + if child.evicted: + new_node.value = None + else: + new_node.value = child.value[:split_len] + child.value = child.value[split_len:] + if child.host_value is not None: + new_node.host_value = child.host_value[:split_len] + child.host_value = child.host_value[split_len:] + child.parent = new_node + child.key = child.key[split_len:] + new_node.parent.children[key[0]] = new_node + return new_node + + def _insert_helper(self, node: TreeNode, key: List, value): + node.last_access_time = time.time() + if len(key) == 0: + return 0 + + if key[0] in node.children.keys(): + child = node.children[key[0]] + prefix_len = _key_match(child.key, key) + + if prefix_len == len(child.key): + if child.evicted: + # change the reference if the node is evicted + # this often happens in the case of KV cache recomputation + child.value = value[:prefix_len] + self.token_to_kv_pool_host.update_synced(child.host_value) + self.evictable_size_ += len(value[:prefix_len]) + return self._insert_helper( + child, key[prefix_len:], value[prefix_len:] + ) + else: + self.inc_hit_count(child) + return prefix_len + self._insert_helper( + child, key[prefix_len:], value[prefix_len:] + ) + + # partial match, split the node + new_node = self._split_node(child.key, child, prefix_len) + if new_node.evicted: + new_node.value = value[:prefix_len] + self.token_to_kv_pool_host.update_synced(new_node.host_value) + self.evictable_size_ += len(new_node.value) + return self._insert_helper( + new_node, key[prefix_len:], value[prefix_len:] + ) + else: + self.inc_hit_count(new_node) + return prefix_len + self._insert_helper( + new_node, key[prefix_len:], value[prefix_len:] + ) + + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.key = key + new_node.value = value + node.children[key[0]] = new_node + self.evictable_size_ += len(value) + + if self.cache_controller.write_policy == "write_through": + self.write_backup(new_node) + return 0 + + def _collect_leaves_device(self): + def is_leaf(node): + if node.evicted: + return False + if node == self.root_node: + return False + if len(node.children) == 0: + return True + for child in node.children.values(): + if not child.evicted: + return False + return True + + ret_list = [] + stack = [self.root_node] + while stack: + cur_node = stack.pop() + if is_leaf(cur_node): + ret_list.append(cur_node) + else: + for cur_child in cur_node.children.values(): + if not cur_child.evicted: + stack.append(cur_child) + return ret_list diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 7b9b35611..29604d443 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -442,7 +442,7 @@ class MLATokenToKVPoolHost: def __init__( self, device_pool: MHATokenToKVPool, - host_to_device_ratio: float = 2.0, + host_to_device_ratio: float = 4.0, pin_memory: bool = False, # no need to use pin memory with the double buffering device: str = "cpu", ): @@ -502,6 +502,9 @@ class MLATokenToKVPoolHost: def get_flat_data(self, indices): return self.kv_buffer[:, :, indices] + def assign_flat_data(self, indices, flat_data): + self.kv_buffer[:, :, indices] = flat_data + @debug_timing def transfer(self, indices, flat_data): # backup prepared data from device to host diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py 
index 68ad77846..be3b82946 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1289,7 +1289,7 @@ def debug_timing(func): tic.record() result = func(*args, **kwargs) toc.record() - torch.cuda.synchronize() # Ensure all CUDA operations are complete + toc.synchronize() # Wait for the function to complete without synchronizing all ops on the GPU elapsed = tic.elapsed_time(toc) indices = kwargs.get("indices", args[1] if len(args) > 1 else None) num_tokens = len(indices) if indices is not None else 0
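For reference, each `bench_multiturn.py` run appends one line to the JSONL log via `log_to_jsonl_file`, so a rate sweep can be summarized offline. A minimal sketch, assuming the default `performance_metrics.jsonl` path and the `summary` layout produced by `WorkloadGenerator.run()`:

```
import json

# Summarize one benchmark sweep: each JSONL line holds a timestamp plus the
# "summary" dict written at the end of WorkloadGenerator.run().
with open("performance_metrics.jsonl") as f:
    for line in f:
        record = json.loads(line)
        s = record["summary"]
        print(
            f"{record['timestamp']}  rate={s['request_rate']:>4}  "
            f"avg_ttft={s['average_ttft']:.2f}  p90_ttft={s['p90_ttft']:.2f}  "
            f"throughput={s['throughput']:.2f} req/s"
        )
```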
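The buffering paths in `cache_controller.py` rely on `CacheOperation.split` to keep staging buffers bounded: an oversized transfer is ceil-divided into `factor` chunks, and only the final chunk inherits the `node_ids`, so the ack fires once every chunk has landed. A standalone sketch of the same index arithmetic, using a plain list in place of the index tensors:

```
import math

def split_indices(indices, factor):
    # Mirrors CacheOperation.split: ceil-divide `indices` into `factor` chunks.
    if factor <= 1:
        return [indices]
    chunk_size = math.ceil(len(indices) / factor)
    return [indices[i : i + chunk_size] for i in range(0, len(indices), chunk_size)]

# e.g. a 1000-token write with max_buffer_size=100 gives factor=10,
# i.e. ten 100-token chunks staged one at a time.
chunks = split_indices(list(range(1000)), 1000 // 100)
assert [len(c) for c in chunks] == [100] * 10
```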
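`TransferBuffer.put` pairs a bounded `queue.Queue` with a shared stop event so that a producer blocked on a full buffer can still be shut down cleanly during `reset()`. The same pattern in isolation, a sketch with hypothetical names (`put_with_stop` is not part of the patch):

```
import queue
import threading

stop_event = threading.Event()
buffers = queue.Queue(maxsize=3)

def put_with_stop(item, timeout=1):
    # Retry on Full with a short timeout, re-checking the stop event each
    # round, instead of blocking indefinitely on a full queue.
    while not stop_event.is_set():
        try:
            buffers.put(item, timeout=timeout)
            return True
        except queue.Full:
            continue
    return False  # shutdown requested while waiting
```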
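The `utils.py` change narrows the synchronization scope of `debug_timing`: instead of `torch.cuda.synchronize()`, which blocks on all outstanding GPU work, only the end event is awaited. The pattern in isolation, a sketch assuming a CUDA device is available:

```
import torch

tic = torch.cuda.Event(enable_timing=True)
toc = torch.cuda.Event(enable_timing=True)

tic.record()
# ... the work being timed, e.g. a host/device copy ...
toc.record()
toc.synchronize()  # waits only for work recorded before `toc` on this stream
print(f"elapsed: {tic.elapsed_time(toc):.3f} ms")
```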