Add cache metrics (#119)
This commit is contained in:
50
python/sglang/srt/constrained/base_cache.py
Normal file
50
python/sglang/srt/constrained/base_cache.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Base cache class."""
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class BaseCache:
|
||||
def __init__(self, enable=True):
|
||||
self.enable = enable
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.cache = {}
|
||||
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
|
||||
|
||||
def query(self, key):
|
||||
def _init_with_timer(key):
|
||||
start = time.monotonic()
|
||||
val = self.init_value(key)
|
||||
init_time = time.monotonic() - start
|
||||
curr_total = self.metrics["total"]
|
||||
new_total = curr_total + 1
|
||||
|
||||
# Update average init time without old_avg * old_total to avoid overflow.
|
||||
self.metrics["avg_init_time"] = (init_time / new_total) + (
|
||||
curr_total / new_total
|
||||
) * self.metrics["avg_init_time"]
|
||||
self.metrics["total"] += 1
|
||||
return val
|
||||
|
||||
if key in self.cache:
|
||||
self.metrics["hit"] += 1
|
||||
val = self.cache[key]
|
||||
else:
|
||||
# Cache miss or disabled.
|
||||
val = _init_with_timer(key)
|
||||
|
||||
if self.enable:
|
||||
self.cache[key] = val
|
||||
return val
|
||||
|
||||
def init_value(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_hit_rate(self):
|
||||
if self.metrics["total"] == 0:
|
||||
return 0
|
||||
return self.metrics["hit"] / self.metrics["total"]
|
||||
|
||||
def get_avg_init_time(self):
|
||||
return self.metrics["avg_init_time"]
|
||||
@@ -1,4 +1,5 @@
|
||||
import interegular
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
from sglang.srt.constrained.disk_cache import disk_cache
|
||||
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
|
||||
|
||||
@@ -56,15 +57,12 @@ class FastForwardMap:
|
||||
return fast_forward_str, next_state
|
||||
|
||||
|
||||
class FastForwardCache:
|
||||
class FastForwardCache(BaseCache):
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
super().__init__()
|
||||
|
||||
def init_fast_forward_map(self, regex_string):
|
||||
if regex_string not in self.cache:
|
||||
fast_forward_map = FastForwardMap(regex_string)
|
||||
self.cache[regex_string] = fast_forward_map
|
||||
return self.cache[regex_string]
|
||||
def init_value(self, regex):
|
||||
return FastForwardMap(regex)
|
||||
|
||||
|
||||
def test_main():
|
||||
|
||||
@@ -1,21 +1,14 @@
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
from sglang.srt.constrained.fsm import RegexFSM
|
||||
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
||||
|
||||
_enable_memory_cache = True
|
||||
|
||||
|
||||
class FSMCache:
|
||||
def __init__(self, tokenizer_path, tokenizer_args_dict):
|
||||
self.cache = {}
|
||||
class FSMCache(BaseCache):
|
||||
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
||||
super().__init__(enable=enable)
|
||||
self.outlines_tokenizer = TransformerTokenizer(
|
||||
tokenizer_path, **tokenizer_args_dict
|
||||
)
|
||||
|
||||
def init_fsm(self, regex):
|
||||
if _enable_memory_cache:
|
||||
if regex not in self.cache:
|
||||
fsm = RegexFSM(regex, self.outlines_tokenizer)
|
||||
self.cache[regex] = fsm
|
||||
return self.cache[regex]
|
||||
|
||||
def init_value(self, regex):
|
||||
return RegexFSM(regex, self.outlines_tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user