Flush Cache API (#103)
@@ -1,60 +0,0 @@
-"""Flush cache in the backend by sending random requests."""
-import argparse
-import random
-import string
-import time
-
-from sglang.test.test_utils import (
-    add_common_sglang_args_and_parse,
-    select_sglang_backend,
-)
-
-import sglang as sgl
-
-
-@sgl.function
-def flush_radix_cache(s, prompt):
-    s += prompt + sgl.gen("flush", max_tokens=1, stop="END")
-
-
-def main(args, max_total_tokens, context_length, print_flag):
-    backend = select_sglang_backend(args)
-    flush_length = int(context_length * 0.8)
-    batch_size = int(max_total_tokens / flush_length)
-    prompt_length = flush_length * 2
-    prompts = [
-        " ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
-        for _ in range(batch_size)
-    ]
-    arguments = [{"prompt": prompts[i]} for i in range(batch_size)]
-
-    start_time = time.time()
-    flush_radix_cache.run_batch(
-        arguments, temperature=0, backend=backend, num_threads=1
-    )
-    end_time = time.time()
-
-    if print_flag:
-        print(
-            f"Flush length: {flush_length}\n",
-            f"Prompt length: {prompt_length}\n",
-            f"Total Prompt letters: {batch_size * prompt_length}\n",
-            f"Flush radix cache latency: {end_time - start_time:.3f}",
-            sep="",
-        )
-
-    # to prevent the backend still running
-    time.sleep(1)
-
-
-def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
-    main(args, max_total_tokens, context_length, print_flag=print_flag)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--max-total-tokens", type=int, default=20000)
-    parser.add_argument("--context-length", type=int, default=1024)
-    args = add_common_sglang_args_and_parse(parser)
-    random.seed(0)
-    main(args, args.max_total_tokens, args.context_length, print_flag=True)
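Note: the file removed above looks like the old workaround, which "flushed" the radix cache indirectly by submitting a batch of random prompts sized to evict whatever was already cached; the remaining hunks in this commit replace it with an explicit flush API. A rough, standalone sketch of the sizing arithmetic that script used, with its default values (plain Python, not an SGLang API):

max_total_tokens = 20000   # assumed token capacity of the backend's KV cache pool
context_length = 1024      # model context length

flush_length = int(context_length * 0.8)           # 819: target tokens to occupy per request
batch_size = int(max_total_tokens / flush_length)  # 24: enough requests to fill the pool
prompt_length = flush_length * 2                   # 1638: random letters per prompt

print(flush_length, batch_size, prompt_length)     # -> 819 24 1638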
@@ -87,3 +87,8 @@ class BatchStrOut:
     output_str: List[str]
     meta_info: List[Dict]
     finished: List[bool]
+
+
+@dataclass
+class FlushCacheReq:
+    pass
@@ -15,7 +15,11 @@ from rpyc.utils.server import ThreadedServer
 from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
+from sglang.srt.managers.io_struct import (
+    BatchTokenIDOut,
+    TokenizedGenerateReqInput,
+    FlushCacheReq,
+)
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
@@ -127,6 +131,22 @@ class ModelRpcServer(rpyc.Service):
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
 
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                "Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
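Note: the new ModelRpcServer.flush_cache() only clears state when the router is idle; it resets the radix tree, the request-to-token pool, and the token-to-KV pool and empties the CUDA allocator cache, otherwise it warns and leaves everything untouched. A toy sketch of the same guard-then-reset pattern (illustrative only, not the SGLang scheduler):

import warnings

class ToyScheduler:
    def __init__(self):
        self.forward_queue = []    # requests waiting to be scheduled
        self.running_batch = None  # batch currently being decoded, if any
        self.cache = {}

    def flush_cache(self):
        # Only safe to drop cached state when nothing is queued or running.
        if len(self.forward_queue) == 0 and self.running_batch is None:
            self.cache.clear()
            print("Cache flushed successfully!")
        else:
            warnings.warn("Cache not flushed because there are pending requests.")

ToyScheduler().flush_cache()  # prints: Cache flushed successfully!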
@@ -136,6 +156,8 @@ class ModelRpcServer(rpyc.Service):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
+            elif isinstance(recv_req, FlushCacheReq):
+                self.flush_cache()
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -30,14 +30,17 @@ def match(key, seq):
 
 class RadixCache:
     def __init__(self, disable=False):
+        self.reset()
+        self.disable = disable
+
+    ##### Public API #####
+
+    def reset(self):
         self.root_node = TreeNode()
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
-        self.disable = disable
 
-    ##### Public API #####
-
     def match_prefix(self, key):
         if self.disable:
             return [], self.root_node
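Note: this hunk factors the tree initialization out of __init__ into reset(), so flush_cache() above can wipe the tree in place instead of constructing a new RadixCache. A generic sketch of that refactor pattern (toy class, not the real TreeNode-based cache):

class ToyCache:
    def __init__(self):
        self.reset()  # constructing and resetting share one code path

    def reset(self):
        self.entries = {}  # drop everything that was cached
        self.size = 0

cache = ToyCache()
cache.entries["some prompt"] = [1, 2, 3]
cache.reset()          # back to the freshly-constructed state
print(cache.entries)   # {}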
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     GenerateReqInput,
     TokenizedGenerateReqInput,
+    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
@@ -228,6 +229,10 @@ class TokenizerManager:
 
                 yield output_list
 
+    async def flush_cache(self):
+        flush_cache_req = FlushCacheReq()
+        self.send_to_router.send_pyobj(flush_cache_req)
+
     async def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
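Note: TokenizerManager.flush_cache() is fire-and-forget; it sends a FlushCacheReq object over the existing ZMQ socket to the router and does not wait for a reply, so the outcome is only reported in the router's logs. A self-contained toy illustration of that pickled-object messaging style (plain pyzmq, not SGLang's actual socket setup):

from dataclasses import dataclass

import zmq

@dataclass
class FlushCacheReq:
    pass

ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)                 # stand-in for the router side
port = pull.bind_to_random_port("tcp://127.0.0.1")
push = ctx.socket(zmq.PUSH)                 # stand-in for the tokenizer manager side
push.connect(f"tcp://127.0.0.1:{port}")

push.send_pyobj(FlushCacheReq())            # fire-and-forget, no reply expected
msg = pull.recv_pyobj()
print(isinstance(msg, FlushCacheReq))       # True -> router would call flush_cache()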
@@ -71,6 +71,15 @@ async def get_model_info():
     return result
 
 
+@app.get("/flush_cache")
+async def flush_cache():
+    await tokenizer_manager.flush_cache()
+    return Response(
+        content="Cache flushed.\nPlease check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)\n",
+        status_code=200,
+    )
+
+
 async def stream_generator(obj):
     async for out in tokenizer_manager.generate_request(obj):
         yield out
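Note: with the endpoint above, any HTTP client can trigger a flush. For example (port 30000 is only the usual SGLang server default; adjust to wherever the server is launched):

import requests

resp = requests.get("http://127.0.0.1:30000/flush_cache")
print(resp.status_code)  # 200
print(resp.text)         # "Cache flushed. ..."; check backend logs if requests were pending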