Flush Cache API (#103)
This commit is contained in:
@@ -87,3 +87,8 @@ class BatchStrOut:
|
||||
output_str: List[str]
|
||||
meta_info: List[Dict]
|
||||
finished: List[bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlushCacheReq:
|
||||
pass
|
||||
|
||||
@@ -15,7 +15,11 @@ from rpyc.utils.server import ThreadedServer
|
||||
from sglang.srt.constrained.fast_forward import FastForwardCache
|
||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchTokenIDOut,
|
||||
TokenizedGenerateReqInput,
|
||||
FlushCacheReq,
|
||||
)
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
||||
@@ -127,6 +131,22 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
|
||||
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
|
||||
|
||||
def flush_cache(self):
|
||||
if len(self.forward_queue) == 0 and (
|
||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||
):
|
||||
self.tree_cache.reset()
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool.clear()
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("Cache flushed successfully!")
|
||||
else:
|
||||
warnings.warn(
|
||||
"Cache not flushed because there are pending requests. "
|
||||
f"#queue-req: {len(self.forward_queue)}, "
|
||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||
)
|
||||
|
||||
def exposed_step(self, recv_reqs):
|
||||
if self.tp_size != 1:
|
||||
recv_reqs = obtain(recv_reqs)
|
||||
@@ -136,6 +156,8 @@ class ModelRpcServer(rpyc.Service):
|
||||
for recv_req in recv_reqs:
|
||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
||||
self.handle_generate_request(recv_req)
|
||||
elif isinstance(recv_req, FlushCacheReq):
|
||||
self.flush_cache()
|
||||
else:
|
||||
raise ValueError(f"Invalid request: {recv_req}")
|
||||
|
||||
|
||||
@@ -30,14 +30,17 @@ def match(key, seq):
|
||||
|
||||
class RadixCache:
|
||||
def __init__(self, disable=False):
|
||||
self.reset()
|
||||
self.disable = disable
|
||||
|
||||
##### Public API #####
|
||||
|
||||
def reset(self):
|
||||
self.root_node = TreeNode()
|
||||
self.root_node.value = []
|
||||
self.root_node.ref_counter = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
self.disable = disable
|
||||
|
||||
##### Public API #####
|
||||
def match_prefix(self, key):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchStrOut,
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
FlushCacheReq,
|
||||
)
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
@@ -228,6 +229,10 @@ class TokenizerManager:
|
||||
|
||||
yield output_list
|
||||
|
||||
async def flush_cache(self):
|
||||
flush_cache_req = FlushCacheReq()
|
||||
self.send_to_router.send_pyobj(flush_cache_req)
|
||||
|
||||
async def create_handle_loop(self):
|
||||
self.to_create_loop = False
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -71,6 +71,15 @@ async def get_model_info():
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/flush_cache")
|
||||
async def flush_cache():
|
||||
await tokenizer_manager.flush_cache()
|
||||
return Response(
|
||||
content="Cache flushed.\nPlease check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
async def stream_generator(obj):
|
||||
async for out in tokenizer_manager.generate_request(obj):
|
||||
yield out
|
||||
|
||||
Reference in New Issue
Block a user