Add cache metrics (#119)
This commit is contained in:
50
python/sglang/srt/constrained/base_cache.py
Normal file
50
python/sglang/srt/constrained/base_cache.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Base cache class."""
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class BaseCache:
    """Base class for a key -> value cache with hit-rate and init-time metrics.

    Subclasses implement ``init_value(key)`` to build the value for a missing
    key. When ``enable`` is False, values are still built (and timed) on every
    query but never stored, which makes every query a miss.
    """

    def __init__(self, enable=True):
        # enable: when False, query() recomputes on every call instead of caching.
        self.enable = enable
        self.reset()

    def reset(self):
        """Clear the cache and zero all metrics."""
        self.cache = {}
        # total: number of query() calls; hit: number served from the cache;
        # avg_init_time: running mean of init_value() wall time in seconds.
        self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}

    def query(self, key):
        """Return the value for ``key``, building it via init_value() on a miss.

        Updates metrics: ``total`` counts every query (hit or miss) so that
        get_cache_hit_rate() is hits / queries; ``avg_init_time`` is averaged
        over the number of values actually initialized.
        """

        def _init_with_timer(key):
            start = time.monotonic()
            val = self.init_value(key)
            init_time = time.monotonic() - start
            # Number of initializations so far, including this one.
            # (total already counts this query; hits never initialize.)
            init_count = self.metrics["total"] - self.metrics["hit"]
            # Incremental mean: avoids old_avg * old_count, which could overflow.
            self.metrics["avg_init_time"] = (init_time / init_count) + (
                (init_count - 1) / init_count
            ) * self.metrics["avg_init_time"]
            return val

        # Count every query so the hit rate is hits / queries, not hits / misses.
        self.metrics["total"] += 1
        if key in self.cache:
            self.metrics["hit"] += 1
            val = self.cache[key]
        else:
            # Cache miss or caching disabled.
            val = _init_with_timer(key)
            if self.enable:
                self.cache[key] = val
        return val

    def init_value(self, key):
        """Build the value for ``key``. Must be overridden by subclasses."""
        raise NotImplementedError

    def get_cache_hit_rate(self):
        """Return hits / total queries in [0.0, 1.0]; 0.0 before any query."""
        if self.metrics["total"] == 0:
            return 0.0
        return self.metrics["hit"] / self.metrics["total"]

    def get_avg_init_time(self):
        """Return the mean init_value() wall time in seconds (0 if none ran)."""
        return self.metrics["avg_init_time"]
|
||||
@@ -1,4 +1,5 @@
|
||||
import interegular
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
from sglang.srt.constrained.disk_cache import disk_cache
|
||||
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
|
||||
|
||||
@@ -56,15 +57,12 @@ class FastForwardMap:
|
||||
return fast_forward_str, next_state
|
||||
|
||||
|
||||
class FastForwardCache:
|
||||
class FastForwardCache(BaseCache):
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
super().__init__()
|
||||
|
||||
def init_fast_forward_map(self, regex_string):
|
||||
if regex_string not in self.cache:
|
||||
fast_forward_map = FastForwardMap(regex_string)
|
||||
self.cache[regex_string] = fast_forward_map
|
||||
return self.cache[regex_string]
|
||||
def init_value(self, regex):
|
||||
return FastForwardMap(regex)
|
||||
|
||||
|
||||
def test_main():
|
||||
|
||||
@@ -1,21 +1,14 @@
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
from sglang.srt.constrained.fsm import RegexFSM
|
||||
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
||||
|
||||
_enable_memory_cache = True
|
||||
|
||||
|
||||
class FSMCache:
|
||||
def __init__(self, tokenizer_path, tokenizer_args_dict):
|
||||
self.cache = {}
|
||||
class FSMCache(BaseCache):
|
||||
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
||||
super().__init__(enable=enable)
|
||||
self.outlines_tokenizer = TransformerTokenizer(
|
||||
tokenizer_path, **tokenizer_args_dict
|
||||
)
|
||||
|
||||
def init_fsm(self, regex):
|
||||
if _enable_memory_cache:
|
||||
if regex not in self.cache:
|
||||
fsm = RegexFSM(regex, self.outlines_tokenizer)
|
||||
self.cache[regex] = fsm
|
||||
return self.cache[regex]
|
||||
|
||||
def init_value(self, regex):
|
||||
return RegexFSM(regex, self.outlines_tokenizer)
|
||||
|
||||
@@ -60,7 +60,11 @@ class Req:
|
||||
|
||||
def tokenize_fast_forward(self, fast_forward_str, next_state):
|
||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
||||
if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
|
||||
# FIXME: This logic does not really solve the problem of determining whether
|
||||
# there should be a leading space.
|
||||
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
|
||||
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
|
||||
if first_token.startswith("▁"):
|
||||
old_output_str = " " + old_output_str
|
||||
new_input_string = (
|
||||
self.input_text
|
||||
|
||||
@@ -4,8 +4,7 @@ import multiprocessing
|
||||
import time
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import rpyc
|
||||
@@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
|
||||
# Init cache
|
||||
self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
self.scheduler = Scheduler(
|
||||
self.schedule_heuristic,
|
||||
self.max_num_running_seq,
|
||||
@@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||
):
|
||||
self.tree_cache.reset()
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
self.regex_fsm_cache.reset()
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool.clear()
|
||||
torch.cuda.empty_cache()
|
||||
@@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service):
|
||||
|
||||
# Init regex fsm
|
||||
if req.sampling_params.regex is not None:
|
||||
req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
|
||||
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||
if not self.no_regex_fast_forward:
|
||||
req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
|
||||
req.fast_forward_map = self.fast_forward_cache.query(
|
||||
req.sampling_params.regex
|
||||
)
|
||||
|
||||
@@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service):
|
||||
can_run_list = []
|
||||
new_batch_total_tokens = 0
|
||||
new_batch_input_tokens = 0
|
||||
new_batch_prefix_tokens = 0
|
||||
|
||||
available_size = (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
@@ -343,12 +344,26 @@ class ModelRpcServer(rpyc.Service):
|
||||
return None
|
||||
|
||||
if self.tp_rank == 0:
|
||||
running_req = 0 if self.running_batch is None else len(self.running_batch.reqs)
|
||||
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
|
||||
self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
|
||||
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
|
||||
tree_cache_hit_rate = (
|
||||
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
||||
)
|
||||
logger.info(
|
||||
f"new fill batch. #seq: {len(can_run_list)}. "
|
||||
f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
|
||||
f"#cached_token: {hit_tokens}. "
|
||||
f"#new_token: {new_batch_input_tokens}. "
|
||||
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
|
||||
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||
f"#running_req: {running_req}. "
|
||||
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
|
||||
)
|
||||
logger.debug(
|
||||
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
||||
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
||||
f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||
f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
|
||||
)
|
||||
|
||||
new_batch = Batch.init_new(
|
||||
|
||||
Reference in New Issue
Block a user