From 62b3812b696862588e7f88533bde5cc57e8d2acf Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 9 Apr 2024 23:27:31 +0800 Subject: [PATCH] Time cost utils (#355) --- python/sglang/backend/openai.py | 3 +- python/sglang/srt/constrained/fsm_cache.py | 1 + python/sglang/srt/server.py | 6 +- python/sglang/srt/server_args.py | 21 ++++--- python/sglang/srt/utils.py | 70 ++++++++++++---------- test/srt/model/bench_llama_low_api.py | 12 ++-- 6 files changed, 66 insertions(+), 47 deletions(-) diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index f2dd2f067..6cad2f6aa 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -9,9 +9,8 @@ from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams try: - import tiktoken - import openai + import tiktoken except ImportError as e: openai = tiktoken = e diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 1d8fbfc67..9c467ab33 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -7,6 +7,7 @@ class FSMCache(BaseCache): super().__init__(enable=enable) from importlib.metadata import version + if version("outlines") >= "0.0.35": from transformers import AutoTokenizer diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 25de8e16c..aa3e5291b 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -53,7 +53,7 @@ from sglang.srt.managers.openai_protocol import ( from sglang.srt.managers.router.manager import start_router_process from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import handle_port_init +from sglang.srt.utils import enable_show_time_cost, handle_port_init from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse @@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer): global tokenizer_manager global chat_template_name + # start show time thread + if server_args.show_time_cost: + enable_show_time_cost() + # disable disk cache if needed if server_args.disable_disk_cache: disable_cache() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b59fd1c0c..4ed14a6be 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,13 +26,14 @@ class ServerArgs: disable_log_stats: bool = False log_stats_interval: int = 10 log_level: str = "info" + api_key: str = "" + show_time_cost: bool = False # optional modes disable_radix_cache: bool = False enable_flashinfer: bool = False disable_regex_jump_forward: bool = False disable_disk_cache: bool = False - api_key: str = "" def __post_init__(self): if self.tokenizer_path is None: @@ -181,6 +182,18 @@ class ServerArgs: default=ServerArgs.log_stats_interval, help="Log stats interval in second.", ) + parser.add_argument( + "--api-key", + type=str, + default=ServerArgs.api_key, + help="Set API Key", + ) + parser.add_argument( + "--show-time-cost", + action="store_true", + help="Show time cost of custom marks", + ) + # optional modes parser.add_argument( "--disable-radix-cache", @@ -202,12 +215,6 @@ class ServerArgs: action="store_true", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", ) - parser.add_argument( - "--api-key", - type=str, - default=ServerArgs.api_key, - help="Set API Key", - ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 86680c3bb..0f7322bb6 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -11,48 +11,56 @@ from typing import List, Optional import numpy as np import requests import torch -import torch.distributed as dist -is_show_cost_time = False +show_time_cost = False +time_infos = {} -def mark_cost_time(func_name): - def inner_func(func): - def time_func(*args, **kwargs): - if dist.get_rank() in [0, 1] and is_show_cost_time: - torch.cuda.synchronize() - start_time = time.time() - ans = func(*args, **kwargs) - torch.cuda.synchronize() - print(func_name, "cost time:", (time.time() - start_time) * 1000) - return ans - else: - torch.cuda.synchronize() - ans = func(*args, **kwargs) - torch.cuda.synchronize() - return ans - - return time_func - - return inner_func +def enable_show_time_cost(): + global show_time_cost + show_time_cost = True -time_mark = {} +class TimeInfo: + def __init__(self, name, interval=0.1, color=0, indent=0): + self.name = name + self.interval = interval + self.color = color + self.indent = indent + + self.acc_time = 0 + self.last_acc_time = 0 + + def check(self): + if self.acc_time - self.last_acc_time > self.interval: + self.last_acc_time = self.acc_time + return True + return False + + def pretty_print(self): + print(f"\x1b[{self.color}m", end="") + print("-" * self.indent * 2, end="") + print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m") -def mark_start(key): +def mark_start(name, interval=0.1, color=0, indent=0): + global time_infos, show_time_cost + if not show_time_cost: + return torch.cuda.synchronize() - global time_mark - time_mark[key] = time.time() - return + if time_infos.get(name, None) is None: + time_infos[name] = TimeInfo(name, interval, color, indent) + time_infos[name].acc_time -= time.time() -def mark_end(key, print_min_cost=0.0): +def mark_end(name): + global time_infos, show_time_cost + if not show_time_cost: + return torch.cuda.synchronize() - global time_mark - cost_time = (time.time() - time_mark[key]) * 1000 - if cost_time > print_min_cost: - print(f"cost {key}:", cost_time) + time_infos[name].acc_time += time.time() + if time_infos[name].check(): + time_infos[name].pretty_print() def calculate_time(show=False, min_cost_ms=0.0): diff --git a/test/srt/model/bench_llama_low_api.py b/test/srt/model/bench_llama_low_api.py index 34c64cd6c..973907274 100644 --- a/test/srt/model/bench_llama_low_api.py +++ b/test/srt/model/bench_llama_low_api.py @@ -66,9 +66,9 @@ class BenchBatch: p_idx = prefix_req_idx[i // fork_num].item() n_idx = self.req_pool_indices[i].item() req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len] - req_to_token[ - n_idx, prefix_len : prefix_len + extend_len - ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len] + req_to_token[n_idx, prefix_len : prefix_len + extend_len] = ( + self.out_cache_loc[i * extend_len : (i + 1) * extend_len] + ) def update_decode(self, predict_ids, batch_size): assert predict_ids.shape[0] == batch_size @@ -81,9 +81,9 @@ class BenchBatch: self.out_cache_cont_start, self.out_cache_cont_end, ) = self.token_to_kv_pool.alloc_contiguous(batch_size) - self.req_to_token_pool.req_to_token[ - self.req_pool_indices, self.seq_lens - ] = self.out_cache_loc + self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = ( + self.out_cache_loc + ) self.seq_lens.add_(1)