Add cache metrics (#119)

This commit is contained in:
Cody Yu
2024-01-30 22:13:14 -08:00
committed by GitHub
parent 74b3bfaaf8
commit 71b54eea7d
5 changed files with 87 additions and 27 deletions

View File

@@ -0,0 +1,50 @@
"""Base cache class."""
import time
class BaseCache:
    """Base class for caches that memoize expensive ``init_value`` results.

    Tracks simple metrics: total queries, cache hits, and a running average
    of the wall-clock time spent computing values on a miss.
    """

    def __init__(self, enable=True):
        # When `enable` is False, values are still computed (and timed) but
        # never stored, so every query is a miss.
        self.enable = enable
        self.reset()

    def reset(self):
        """Clear the cache and zero all metrics."""
        self.cache = {}
        # "total": every query (hit or miss); "hit": hits only;
        # "avg_init_time": mean seconds per init_value() call.
        self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}

    def query(self, key):
        """Return the value for `key`, computing it via init_value on a miss."""
        # BUG FIX: "total" previously counted only misses (it was incremented
        # inside the miss path), so get_cache_hit_rate() returned hits/misses
        # and could exceed 100%. Count every query here instead.
        self.metrics["total"] += 1
        if key in self.cache:
            self.metrics["hit"] += 1
            return self.cache[key]

        # Cache miss, or caching disabled.
        val = self._init_with_timer(key)
        if self.enable:
            self.cache[key] = val
        return val

    def _init_with_timer(self, key):
        """Run init_value() and fold its duration into the running average."""
        start = time.monotonic()
        val = self.init_value(key)
        init_time = time.monotonic() - start

        # Number of init_value() calls so far == misses == total - hit
        # (the current query is already counted in "total").
        n_inits = self.metrics["total"] - self.metrics["hit"]
        # Incremental mean: avoids old_avg * old_count, which could overflow.
        self.metrics["avg_init_time"] = (init_time / n_inits) + (
            (n_inits - 1) / n_inits
        ) * self.metrics["avg_init_time"]
        return val

    def init_value(self, key):
        """Compute the value for `key`. Subclasses must override."""
        raise NotImplementedError

    def get_cache_hit_rate(self):
        """Fraction of queries served from the cache (0 before any query)."""
        if self.metrics["total"] == 0:
            return 0
        return self.metrics["hit"] / self.metrics["total"]

    def get_avg_init_time(self):
        """Average wall-clock seconds per init_value() call (0 if none yet)."""
        return self.metrics["avg_init_time"]

View File

@@ -1,4 +1,5 @@
import interegular import interegular
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.disk_cache import disk_cache from sglang.srt.constrained.disk_cache import disk_cache
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
@@ -56,15 +57,12 @@ class FastForwardMap:
return fast_forward_str, next_state return fast_forward_str, next_state
class FastForwardCache: class FastForwardCache(BaseCache):
def __init__(self): def __init__(self):
self.cache = {} super().__init__()
def init_fast_forward_map(self, regex_string): def init_value(self, regex):
if regex_string not in self.cache: return FastForwardMap(regex)
fast_forward_map = FastForwardMap(regex_string)
self.cache[regex_string] = fast_forward_map
return self.cache[regex_string]
def test_main(): def test_main():

View File

@@ -1,21 +1,14 @@
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.fsm import RegexFSM from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer from sglang.srt.constrained.tokenizer import TransformerTokenizer
_enable_memory_cache = True
class FSMCache(BaseCache):
class FSMCache: def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
def __init__(self, tokenizer_path, tokenizer_args_dict): super().__init__(enable=enable)
self.cache = {}
self.outlines_tokenizer = TransformerTokenizer( self.outlines_tokenizer = TransformerTokenizer(
tokenizer_path, **tokenizer_args_dict tokenizer_path, **tokenizer_args_dict
) )
def init_fsm(self, regex): def init_value(self, regex):
if _enable_memory_cache:
if regex not in self.cache:
fsm = RegexFSM(regex, self.outlines_tokenizer)
self.cache[regex] = fsm
return self.cache[regex]
return RegexFSM(regex, self.outlines_tokenizer) return RegexFSM(regex, self.outlines_tokenizer)

View File

@@ -60,7 +60,11 @@ class Req:
def tokenize_fast_forward(self, fast_forward_str, next_state): def tokenize_fast_forward(self, fast_forward_str, next_state):
old_output_str = self.tokenizer.decode(self.output_ids) old_output_str = self.tokenizer.decode(self.output_ids)
if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith(""): # FIXME: This logic does not really solve the problem of determining whether
# there should be a leading space.
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
if first_token.startswith(""):
old_output_str = " " + old_output_str old_output_str = " " + old_output_str
new_input_string = ( new_input_string = (
self.input_text self.input_text

View File

@@ -4,8 +4,7 @@ import multiprocessing
import time import time
import warnings import warnings
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto from typing import List
from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import rpyc import rpyc
@@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service):
# Init cache # Init cache
self.tree_cache = RadixCache(disable="no-cache" in self.model_mode) self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = Scheduler( self.scheduler = Scheduler(
self.schedule_heuristic, self.schedule_heuristic,
self.max_num_running_seq, self.max_num_running_seq,
@@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service):
self.running_batch is None or len(self.running_batch.reqs) == 0 self.running_batch is None or len(self.running_batch.reqs) == 0
): ):
self.tree_cache.reset() self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.regex_fsm_cache.reset()
self.req_to_token_pool.clear() self.req_to_token_pool.clear()
self.token_to_kv_pool.clear() self.token_to_kv_pool.clear()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service):
# Init regex fsm # Init regex fsm
if req.sampling_params.regex is not None: if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex) req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.no_regex_fast_forward: if not self.no_regex_fast_forward:
req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map( req.fast_forward_map = self.fast_forward_cache.query(
req.sampling_params.regex req.sampling_params.regex
) )
@@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service):
can_run_list = [] can_run_list = []
new_batch_total_tokens = 0 new_batch_total_tokens = 0
new_batch_input_tokens = 0 new_batch_input_tokens = 0
new_batch_prefix_tokens = 0
available_size = ( available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -343,12 +344,26 @@ class ModelRpcServer(rpyc.Service):
return None return None
if self.tp_rank == 0: if self.tp_rank == 0:
running_req = 0 if self.running_batch is None else len(self.running_batch.reqs)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
logger.info( logger.info(
f"new fill batch. #seq: {len(can_run_list)}. " f"new fill batch. #seq: {len(can_run_list)}. "
f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. " f"#cached_token: {hit_tokens}. "
f"#new_token: {new_batch_input_tokens}. " f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. " f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
)
logger.debug(
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
) )
new_batch = Batch.init_new( new_batch = Batch.init_new(