Add cache metrics (#119)
This commit is contained in:
50
python/sglang/srt/constrained/base_cache.py
Normal file
50
python/sglang/srt/constrained/base_cache.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Base cache class."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCache:
    """Base class for a simple enable-able dict cache with hit/miss metrics.

    Subclasses implement `init_value(key)`; callers use `query(key)`, which
    returns the cached value on a hit or builds (and optionally stores) it on
    a miss while tracking hit rate and average build time.
    """

    def __init__(self, enable=True):
        # When `enable` is False, nothing is ever stored, so every query is
        # a miss that rebuilds the value.
        self.enable = enable
        self.reset()

    def reset(self):
        """Clear the cache and zero all metrics."""
        self.cache = {}
        # total: number of queries; hit: number of cache hits;
        # avg_init_time: running mean of init_value() wall time over misses.
        self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}

    def query(self, key):
        """Return the value for `key`, building it via init_value on a miss."""

        def _init_with_timer(key):
            start = time.monotonic()
            val = self.init_value(key)
            init_time = time.monotonic() - start

            # Number of misses so far, including this one (total was already
            # bumped for this query below before we were called).
            num_init = self.metrics["total"] - self.metrics["hit"]
            # Incremental mean: avoids computing old_avg * old_count, which
            # could lose precision / overflow for long-running processes.
            self.metrics["avg_init_time"] += (
                init_time - self.metrics["avg_init_time"]
            ) / num_init
            return val

        # BUGFIX: count every query, not only misses; otherwise
        # get_cache_hit_rate() would return hit/miss-count and exceed 1.0.
        self.metrics["total"] += 1
        if key in self.cache:
            self.metrics["hit"] += 1
            val = self.cache[key]
        else:
            # Cache miss or caching disabled.
            val = _init_with_timer(key)
            if self.enable:
                self.cache[key] = val
        return val

    def init_value(self, key):
        """Build the value for `key`. Subclasses must override."""
        raise NotImplementedError

    def get_cache_hit_rate(self):
        """Return hits / total queries, or 0 if nothing was queried yet."""
        if self.metrics["total"] == 0:
            return 0
        return self.metrics["hit"] / self.metrics["total"]

    def get_avg_init_time(self):
        """Return the running average init_value() time over cache misses."""
        return self.metrics["avg_init_time"]
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import interegular
|
import interegular
|
||||||
|
from sglang.srt.constrained.base_cache import BaseCache
|
||||||
from sglang.srt.constrained.disk_cache import disk_cache
|
from sglang.srt.constrained.disk_cache import disk_cache
|
||||||
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
|
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
|
||||||
|
|
||||||
@@ -56,15 +57,12 @@ class FastForwardMap:
|
|||||||
return fast_forward_str, next_state
|
return fast_forward_str, next_state
|
||||||
|
|
||||||
|
|
||||||
class FastForwardCache(BaseCache):
    """Cache of FastForwardMap objects keyed by regex string."""

    def __init__(self, enable=True):
        # Accept `enable` (default True, backward compatible) so this cache
        # can be disabled, matching the sibling FSMCache interface.
        super().__init__(enable=enable)

    def init_value(self, regex):
        """Build the fast-forward map for `regex`.

        Called (and timed) by BaseCache.query on a cache miss.
        """
        return FastForwardMap(regex)
|
|
||||||
|
|
||||||
|
|
||||||
def test_main():
|
def test_main():
|
||||||
|
|||||||
@@ -1,21 +1,14 @@
|
|||||||
|
from sglang.srt.constrained.base_cache import BaseCache
|
||||||
from sglang.srt.constrained.fsm import RegexFSM
|
from sglang.srt.constrained.fsm import RegexFSM
|
||||||
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
||||||
|
|
||||||
_enable_memory_cache = True
|
|
||||||
|
|
||||||
|
class FSMCache(BaseCache):
    """Cache of compiled RegexFSM objects, keyed by regex string."""

    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
        """Load the outlines tokenizer once; FSMs are built lazily via query().

        Args:
            tokenizer_path: path passed to TransformerTokenizer.
            tokenizer_args_dict: extra keyword args for TransformerTokenizer.
            enable: when False, every query rebuilds the FSM (no caching).
        """
        super().__init__(enable=enable)
        tokenizer = TransformerTokenizer(tokenizer_path, **tokenizer_args_dict)
        self.outlines_tokenizer = tokenizer

    def init_value(self, regex):
        """Compile a fresh FSM for `regex`; BaseCache.query times this call."""
        return RegexFSM(regex, self.outlines_tokenizer)
|
||||||
|
|||||||
@@ -60,7 +60,11 @@ class Req:
|
|||||||
|
|
||||||
def tokenize_fast_forward(self, fast_forward_str, next_state):
|
def tokenize_fast_forward(self, fast_forward_str, next_state):
|
||||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
old_output_str = self.tokenizer.decode(self.output_ids)
|
||||||
if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
|
# FIXME: This logic does not really solve the problem of determining whether
|
||||||
|
# there should be a leading space.
|
||||||
|
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
|
||||||
|
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
|
||||||
|
if first_token.startswith("▁"):
|
||||||
old_output_str = " " + old_output_str
|
old_output_str = " " + old_output_str
|
||||||
new_input_string = (
|
new_input_string = (
|
||||||
self.input_text
|
self.input_text
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ import multiprocessing
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum, auto
|
from typing import List
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rpyc
|
import rpyc
|
||||||
@@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
|
|
||||||
# Init cache
|
# Init cache
|
||||||
self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
|
self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
|
||||||
|
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||||
self.scheduler = Scheduler(
|
self.scheduler = Scheduler(
|
||||||
self.schedule_heuristic,
|
self.schedule_heuristic,
|
||||||
self.max_num_running_seq,
|
self.max_num_running_seq,
|
||||||
@@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||||
):
|
):
|
||||||
self.tree_cache.reset()
|
self.tree_cache.reset()
|
||||||
|
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||||
|
self.regex_fsm_cache.reset()
|
||||||
self.req_to_token_pool.clear()
|
self.req_to_token_pool.clear()
|
||||||
self.token_to_kv_pool.clear()
|
self.token_to_kv_pool.clear()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
|
|
||||||
# Init regex fsm
|
# Init regex fsm
|
||||||
if req.sampling_params.regex is not None:
|
if req.sampling_params.regex is not None:
|
||||||
req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
|
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||||
if not self.no_regex_fast_forward:
|
if not self.no_regex_fast_forward:
|
||||||
req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
|
req.fast_forward_map = self.fast_forward_cache.query(
|
||||||
req.sampling_params.regex
|
req.sampling_params.regex
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
can_run_list = []
|
can_run_list = []
|
||||||
new_batch_total_tokens = 0
|
new_batch_total_tokens = 0
|
||||||
new_batch_input_tokens = 0
|
new_batch_input_tokens = 0
|
||||||
new_batch_prefix_tokens = 0
|
|
||||||
|
|
||||||
available_size = (
|
available_size = (
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||||
@@ -343,12 +344,26 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
|
running_req = 0 if self.running_batch is None else len(self.running_batch.reqs)
|
||||||
|
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
|
||||||
|
self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
|
||||||
|
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
|
||||||
|
tree_cache_hit_rate = (
|
||||||
|
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"new fill batch. #seq: {len(can_run_list)}. "
|
f"new fill batch. #seq: {len(can_run_list)}. "
|
||||||
f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
|
f"#cached_token: {hit_tokens}. "
|
||||||
f"#new_token: {new_batch_input_tokens}. "
|
f"#new_token: {new_batch_input_tokens}. "
|
||||||
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
|
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
|
||||||
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
f"#running_req: {running_req}. "
|
||||||
|
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
||||||
|
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
||||||
|
f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||||
|
f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
|
||||||
)
|
)
|
||||||
|
|
||||||
new_batch = Batch.init_new(
|
new_batch = Batch.init_new(
|
||||||
|
|||||||
Reference in New Issue
Block a user