From 71b54eea7d21a2bb1d8ef340e7002983a29b1d5f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 30 Jan 2024 22:13:14 -0800 Subject: [PATCH] Add cache metrics (#119) --- python/sglang/srt/constrained/base_cache.py | 50 +++++++++++++++++++ python/sglang/srt/constrained/fast_forward.py | 12 ++--- python/sglang/srt/constrained/fsm_cache.py | 17 ++----- .../sglang/srt/managers/router/infer_batch.py | 6 ++- .../sglang/srt/managers/router/model_rpc.py | 29 ++++++++--- 5 files changed, 87 insertions(+), 27 deletions(-) create mode 100644 python/sglang/srt/constrained/base_cache.py diff --git a/python/sglang/srt/constrained/base_cache.py b/python/sglang/srt/constrained/base_cache.py new file mode 100644 index 000000000..fb04311d2 --- /dev/null +++ b/python/sglang/srt/constrained/base_cache.py @@ -0,0 +1,50 @@ +"""Base cache class.""" + +import time + + +class BaseCache: + def __init__(self, enable=True): + self.enable = enable + self.reset() + + def reset(self): + self.cache = {} + self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0} + + def query(self, key): + def _init_with_timer(key): + start = time.monotonic() + val = self.init_value(key) + init_time = time.monotonic() - start + curr_total = self.metrics["total"] + new_total = curr_total + 1 + + # Update average init time without old_avg * old_total to avoid overflow. + self.metrics["avg_init_time"] = (init_time / new_total) + ( + curr_total / new_total + ) * self.metrics["avg_init_time"] + self.metrics["total"] += 1 + return val + + if key in self.cache: + self.metrics["hit"] += 1 + val = self.cache[key] + else: + # Cache miss or disabled. + val = _init_with_timer(key) + + if self.enable: + self.cache[key] = val + return val + + def init_value(self, key): + raise NotImplementedError + + def get_cache_hit_rate(self): + if self.metrics["total"] == 0: + return 0 + return self.metrics["hit"] / self.metrics["total"] + + def get_avg_init_time(self): + return self.metrics["avg_init_time"] diff --git a/python/sglang/srt/constrained/fast_forward.py b/python/sglang/srt/constrained/fast_forward.py index 49ac33ea5..d6bb94cb9 100644 --- a/python/sglang/srt/constrained/fast_forward.py +++ b/python/sglang/srt/constrained/fast_forward.py @@ -1,4 +1,5 @@ import interegular +from sglang.srt.constrained.base_cache import BaseCache from sglang.srt.constrained.disk_cache import disk_cache from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm @@ -56,15 +57,12 @@ class FastForwardMap: return fast_forward_str, next_state -class FastForwardCache: +class FastForwardCache(BaseCache): def __init__(self): - self.cache = {} + super().__init__() - def init_fast_forward_map(self, regex_string): - if regex_string not in self.cache: - fast_forward_map = FastForwardMap(regex_string) - self.cache[regex_string] = fast_forward_map - return self.cache[regex_string] + def init_value(self, regex): + return FastForwardMap(regex) def test_main(): diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 7bd815eb1..9acf0f6eb 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -1,21 +1,14 @@ +from sglang.srt.constrained.base_cache import BaseCache from sglang.srt.constrained.fsm import RegexFSM from sglang.srt.constrained.tokenizer import TransformerTokenizer -_enable_memory_cache = True - -class FSMCache: - def __init__(self, tokenizer_path, tokenizer_args_dict): - self.cache = {} +class FSMCache(BaseCache): + def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True): + super().__init__(enable=enable) self.outlines_tokenizer = TransformerTokenizer( tokenizer_path, **tokenizer_args_dict ) - def init_fsm(self, regex): - if _enable_memory_cache: - if regex not in self.cache: - fsm = RegexFSM(regex, self.outlines_tokenizer) - self.cache[regex] = fsm - return self.cache[regex] - + def init_value(self, regex): return RegexFSM(regex, self.outlines_tokenizer) diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index c5aa88615..3a9b88555 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -60,7 +60,11 @@ class Req: def tokenize_fast_forward(self, fast_forward_str, next_state): old_output_str = self.tokenizer.decode(self.output_ids) - if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"): + # FIXME: This logic does not really solve the problem of determining whether + # there should be a leading space. + first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0]) + first_token = first_token.decode() if isinstance(first_token, bytes) else first_token + if first_token.startswith("▁"): old_output_str = " " + old_output_str new_input_string = ( self.input_text diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index eb5fc2f43..88ba48949 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -4,8 +4,7 @@ import multiprocessing import time import warnings from concurrent.futures import ThreadPoolExecutor -from enum import Enum, auto -from typing import Dict, List, Optional, Tuple, Union +from typing import List import numpy as np import rpyc @@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service): # Init cache self.tree_cache = RadixCache(disable="no-cache" in self.model_mode) + self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = Scheduler( self.schedule_heuristic, self.max_num_running_seq, @@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service): self.running_batch is None or len(self.running_batch.reqs) == 0 ): self.tree_cache.reset() + self.tree_cache_metrics = {"total": 0, "hit": 0} + self.regex_fsm_cache.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() torch.cuda.empty_cache() @@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service): # Init regex fsm if req.sampling_params.regex is not None: - req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex) + req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) if not self.no_regex_fast_forward: - req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map( + req.fast_forward_map = self.fast_forward_cache.query( req.sampling_params.regex ) @@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service): can_run_list = [] new_batch_total_tokens = 0 new_batch_input_tokens = 0 - new_batch_prefix_tokens = 0 available_size = ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() @@ -343,12 +344,26 @@ class ModelRpcServer(rpyc.Service): return None if self.tp_rank == 0: + running_req = 0 if self.running_batch is None else len(self.running_batch.reqs) + hit_tokens = sum(len(x.prefix_indices) for x in can_run_list) + self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9 + self.tree_cache_metrics["hit"] += hit_tokens / 10**9 + tree_cache_hit_rate = ( + self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] + ) logger.info( f"new fill batch. #seq: {len(can_run_list)}. " - f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. " + f"#cached_token: {hit_tokens}. " f"#new_token: {new_batch_input_tokens}. " f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. " - f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" + f"#running_req: {running_req}. " + f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%." + ) + logger.debug( + f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " + f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " + f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. " + f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. " ) new_batch = Batch.init_new(