Flush Cache API (#103)
This commit is contained in:
@@ -1,60 +0,0 @@
|
||||
"""Flush cache in the backend by sending random requests."""
|
||||
import argparse
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
@sgl.function
|
||||
def flush_radix_cache(s, prompt):
|
||||
s += prompt + sgl.gen("flush", max_tokens=1, stop="END")
|
||||
|
||||
|
||||
def main(args, max_total_tokens, context_length, print_flag):
|
||||
backend = select_sglang_backend(args)
|
||||
flush_length = int(context_length * 0.8)
|
||||
batch_size = int(max_total_tokens / flush_length)
|
||||
prompt_length = flush_length * 2
|
||||
prompts = [
|
||||
" ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
arguments = [{"prompt": prompts[i]} for i in range(batch_size)]
|
||||
|
||||
start_time = time.time()
|
||||
flush_radix_cache.run_batch(
|
||||
arguments, temperature=0, backend=backend, num_threads=1
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
if print_flag:
|
||||
print(
|
||||
f"Flush length: {flush_length}\n",
|
||||
f"Prompt length: {prompt_length}\n",
|
||||
f"Total Prompt letters: {batch_size * prompt_length}\n",
|
||||
f"Flush radix cache latency: {end_time - start_time:.3f}",
|
||||
sep="",
|
||||
)
|
||||
|
||||
# to prevent the backend still running
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
|
||||
main(args, max_total_tokens, context_length, print_flag=print_flag)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-total-tokens", type=int, default=20000)
|
||||
parser.add_argument("--context-length", type=int, default=1024)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
random.seed(0)
|
||||
main(args, args.max_total_tokens, args.context_length, print_flag=True)
|
||||
@@ -87,3 +87,8 @@ class BatchStrOut:
|
||||
output_str: List[str]
|
||||
meta_info: List[Dict]
|
||||
finished: List[bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlushCacheReq:
|
||||
pass
|
||||
|
||||
@@ -15,7 +15,11 @@ from rpyc.utils.server import ThreadedServer
|
||||
from sglang.srt.constrained.fast_forward import FastForwardCache
|
||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchTokenIDOut,
|
||||
TokenizedGenerateReqInput,
|
||||
FlushCacheReq,
|
||||
)
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
||||
@@ -127,6 +131,22 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
|
||||
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
|
||||
|
||||
def flush_cache(self):
|
||||
if len(self.forward_queue) == 0 and (
|
||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||
):
|
||||
self.tree_cache.reset()
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool.clear()
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("Cache flushed successfully!")
|
||||
else:
|
||||
warnings.warn(
|
||||
"Cache not flushed because there are pending requests. "
|
||||
f"#queue-req: {len(self.forward_queue)}, "
|
||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||
)
|
||||
|
||||
def exposed_step(self, recv_reqs):
|
||||
if self.tp_size != 1:
|
||||
recv_reqs = obtain(recv_reqs)
|
||||
@@ -136,6 +156,8 @@ class ModelRpcServer(rpyc.Service):
|
||||
for recv_req in recv_reqs:
|
||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
||||
self.handle_generate_request(recv_req)
|
||||
elif isinstance(recv_req, FlushCacheReq):
|
||||
self.flush_cache()
|
||||
else:
|
||||
raise ValueError(f"Invalid request: {recv_req}")
|
||||
|
||||
|
||||
@@ -30,14 +30,17 @@ def match(key, seq):
|
||||
|
||||
class RadixCache:
|
||||
def __init__(self, disable=False):
|
||||
self.reset()
|
||||
self.disable = disable
|
||||
|
||||
##### Public API #####
|
||||
|
||||
def reset(self):
|
||||
self.root_node = TreeNode()
|
||||
self.root_node.value = []
|
||||
self.root_node.ref_counter = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
self.disable = disable
|
||||
|
||||
##### Public API #####
|
||||
def match_prefix(self, key):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchStrOut,
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
FlushCacheReq,
|
||||
)
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
@@ -228,6 +229,10 @@ class TokenizerManager:
|
||||
|
||||
yield output_list
|
||||
|
||||
async def flush_cache(self):
|
||||
flush_cache_req = FlushCacheReq()
|
||||
self.send_to_router.send_pyobj(flush_cache_req)
|
||||
|
||||
async def create_handle_loop(self):
|
||||
self.to_create_loop = False
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -71,6 +71,15 @@ async def get_model_info():
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/flush_cache")
|
||||
async def flush_cache():
|
||||
await tokenizer_manager.flush_cache()
|
||||
return Response(
|
||||
content="Cache flushed.\nPlease check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
async def stream_generator(obj):
|
||||
async for out in tokenizer_manager.generate_request(obj):
|
||||
yield out
|
||||
|
||||
Reference in New Issue
Block a user