From 81561f8e2d55d105aabbe0eab1b3b33f4fc04b0b Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Fri, 26 Jan 2024 13:32:59 +0800
Subject: [PATCH] Flush Cache API (#103)

---
 python/sglang/flush_cache.py                  | 60 ------
 python/sglang/srt/managers/io_struct.py       |  5 ++
 .../sglang/srt/managers/router/model_rpc.py   | 24 +++++++-
 .../sglang/srt/managers/router/radix_cache.py | 9 ++-
 .../sglang/srt/managers/tokenizer_manager.py  |  5 ++
 python/sglang/srt/server.py                   |  9 +++
 6 files changed, 48 insertions(+), 64 deletions(-)
 delete mode 100644 python/sglang/flush_cache.py

diff --git a/python/sglang/flush_cache.py b/python/sglang/flush_cache.py
deleted file mode 100644
index 6050ee22c..000000000
--- a/python/sglang/flush_cache.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""Flush cache in the backend by sending random requests."""
-import argparse
-import random
-import string
-import time
-
-from sglang.test.test_utils import (
-    add_common_sglang_args_and_parse,
-    select_sglang_backend,
-)
-
-import sglang as sgl
-
-
-@sgl.function
-def flush_radix_cache(s, prompt):
-    s += prompt + sgl.gen("flush", max_tokens=1, stop="END")
-
-
-def main(args, max_total_tokens, context_length, print_flag):
-    backend = select_sglang_backend(args)
-    flush_length = int(context_length * 0.8)
-    batch_size = int(max_total_tokens / flush_length)
-    prompt_length = flush_length * 2
-    prompts = [
-        " ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
-        for _ in range(batch_size)
-    ]
-    arguments = [{"prompt": prompts[i]} for i in range(batch_size)]
-
-    start_time = time.time()
-    flush_radix_cache.run_batch(
-        arguments, temperature=0, backend=backend, num_threads=1
-    )
-    end_time = time.time()
-
-    if print_flag:
-        print(
-            f"Flush length: {flush_length}\n",
-            f"Prompt length: {prompt_length}\n",
-            f"Total Prompt letters: {batch_size * prompt_length}\n",
-            f"Flush radix cache latency: {end_time - start_time:.3f}",
-            sep="",
-        )
-
-    # to prevent the backend still running
-    time.sleep(1)
-
-
-def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
-    main(args, max_total_tokens, context_length, print_flag=print_flag)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--max-total-tokens", type=int, default=20000)
-    parser.add_argument("--context-length", type=int, default=1024)
-    args = add_common_sglang_args_and_parse(parser)
-    random.seed(0)
-    main(args, args.max_total_tokens, args.context_length, print_flag=True)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index c4380c49a..6b6940d1c 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -87,3 +87,8 @@ class BatchStrOut:
     output_str: List[str]
     meta_info: List[Dict]
     finished: List[bool]
+
+
+@dataclass
+class FlushCacheReq:
+    pass
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index dfaa8f12b..199a8974b 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -15,7 +15,11 @@ from rpyc.utils.server import ThreadedServer
 from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
+from sglang.srt.managers.io_struct import (
+    BatchTokenIDOut,
+    TokenizedGenerateReqInput,
+    FlushCacheReq,
+)
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
@@ -127,6 +131,22 @@ class ModelRpcServer(rpyc.Service):
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                "Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -136,6 +156,8 @@ class ModelRpcServer(rpyc.Service):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
+            elif isinstance(recv_req, FlushCacheReq):
+                self.flush_cache()
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

diff --git a/python/sglang/srt/managers/router/radix_cache.py b/python/sglang/srt/managers/router/radix_cache.py
index 25043d7ed..6ee670309 100644
--- a/python/sglang/srt/managers/router/radix_cache.py
+++ b/python/sglang/srt/managers/router/radix_cache.py
@@ -30,14 +30,17 @@ def match(key, seq):

 class RadixCache:
     def __init__(self, disable=False):
+        self.reset()
+        self.disable = disable
+
+    ##### Public API #####
+
+    def reset(self):
         self.root_node = TreeNode()
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
-        self.disable = disable
-
-    ##### Public API #####

     def match_prefix(self, key):
         if self.disable:
             return [], self.root_node
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 85a7a04f6..d08b33634 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     GenerateReqInput,
     TokenizedGenerateReqInput,
+    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
@@ -228,6 +229,10 @@ class TokenizerManager:

         yield output_list

+    async def flush_cache(self):
+        flush_cache_req = FlushCacheReq()
+        self.send_to_router.send_pyobj(flush_cache_req)
+
     async def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 28a930416..9750c4e72 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -71,6 +71,15 @@ async def get_model_info():
     return result


+@app.get("/flush_cache")
+async def flush_cache():
+    await tokenizer_manager.flush_cache()
+    return Response(
+        content="Cache flushed.\nPlease check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)\n",
+        status_code=200,
+    )
+
+
 async def stream_generator(obj):
     async for out in tokenizer_manager.generate_request(obj):
         yield out
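
For reference, the sketch below shows how a client might exercise the new endpoint once a server is running. It is a minimal example, not part of the patch; the host and port are assumptions (30000 is the default port of sglang's launch_server), and any HTTP client would work in place of requests.

import requests

# Ask the backend to flush its radix cache and KV pools. The endpoint always
# returns 200; whether the flush actually ran is reported in the backend logs,
# since flush_cache() is a no-op while requests are queued or running.
resp = requests.get("http://127.0.0.1:30000/flush_cache")
print(resp.status_code)
print(resp.text)

Because the router only resets the tree cache and token pools when both the waiting queue and the running batch are empty, callers should let in-flight generation requests drain before flushing.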