Time cost utils (#355)
This commit is contained in:
@@ -9,9 +9,8 @@ from sglang.lang.interpreter import StreamExecutor
|
|||||||
from sglang.lang.ir import SglSamplingParams
|
from sglang.lang.ir import SglSamplingParams
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tiktoken
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
import tiktoken
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
openai = tiktoken = e
|
openai = tiktoken = e
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ class FSMCache(BaseCache):
|
|||||||
super().__init__(enable=enable)
|
super().__init__(enable=enable)
|
||||||
|
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
|
||||||
if version("outlines") >= "0.0.35":
|
if version("outlines") >= "0.0.35":
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from sglang.srt.managers.openai_protocol import (
|
|||||||
from sglang.srt.managers.router.manager import start_router_process
|
from sglang.srt.managers.router.manager import start_router_process
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import handle_port_init
|
from sglang.srt.utils import enable_show_time_cost, handle_port_init
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
@@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
|
|||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
global chat_template_name
|
global chat_template_name
|
||||||
|
|
||||||
|
# start show time thread
|
||||||
|
if server_args.show_time_cost:
|
||||||
|
enable_show_time_cost()
|
||||||
|
|
||||||
# disable disk cache if needed
|
# disable disk cache if needed
|
||||||
if server_args.disable_disk_cache:
|
if server_args.disable_disk_cache:
|
||||||
disable_cache()
|
disable_cache()
|
||||||
|
|||||||
@@ -26,13 +26,14 @@ class ServerArgs:
|
|||||||
disable_log_stats: bool = False
|
disable_log_stats: bool = False
|
||||||
log_stats_interval: int = 10
|
log_stats_interval: int = 10
|
||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
|
api_key: str = ""
|
||||||
|
show_time_cost: bool = False
|
||||||
|
|
||||||
# optional modes
|
# optional modes
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
enable_flashinfer: bool = False
|
enable_flashinfer: bool = False
|
||||||
disable_regex_jump_forward: bool = False
|
disable_regex_jump_forward: bool = False
|
||||||
disable_disk_cache: bool = False
|
disable_disk_cache: bool = False
|
||||||
api_key: str = ""
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
@@ -181,6 +182,18 @@ class ServerArgs:
|
|||||||
default=ServerArgs.log_stats_interval,
|
default=ServerArgs.log_stats_interval,
|
||||||
help="Log stats interval in second.",
|
help="Log stats interval in second.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-key",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.api_key,
|
||||||
|
help="Set API Key",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--show-time-cost",
|
||||||
|
action="store_true",
|
||||||
|
help="Show time cost of custom marks",
|
||||||
|
)
|
||||||
|
|
||||||
# optional modes
|
# optional modes
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-radix-cache",
|
"--disable-radix-cache",
|
||||||
@@ -202,12 +215,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--api-key",
|
|
||||||
type=str,
|
|
||||||
default=ServerArgs.api_key,
|
|
||||||
help="Set API Key",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
@@ -11,48 +11,56 @@ from typing import List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
is_show_cost_time = False
|
show_time_cost = False
|
||||||
|
time_infos = {}
|
||||||
|
|
||||||
|
|
||||||
def mark_cost_time(func_name):
|
def enable_show_time_cost():
|
||||||
def inner_func(func):
|
global show_time_cost
|
||||||
def time_func(*args, **kwargs):
|
show_time_cost = True
|
||||||
if dist.get_rank() in [0, 1] and is_show_cost_time:
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
start_time = time.time()
|
|
||||||
ans = func(*args, **kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
print(func_name, "cost time:", (time.time() - start_time) * 1000)
|
|
||||||
return ans
|
|
||||||
else:
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
ans = func(*args, **kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
return ans
|
|
||||||
|
|
||||||
return time_func
|
|
||||||
|
|
||||||
return inner_func
|
|
||||||
|
|
||||||
|
|
||||||
time_mark = {}
|
class TimeInfo:
|
||||||
|
def __init__(self, name, interval=0.1, color=0, indent=0):
|
||||||
|
self.name = name
|
||||||
|
self.interval = interval
|
||||||
|
self.color = color
|
||||||
|
self.indent = indent
|
||||||
|
|
||||||
|
self.acc_time = 0
|
||||||
|
self.last_acc_time = 0
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.acc_time - self.last_acc_time > self.interval:
|
||||||
|
self.last_acc_time = self.acc_time
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
print(f"\x1b[{self.color}m", end="")
|
||||||
|
print("-" * self.indent * 2, end="")
|
||||||
|
print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")
|
||||||
|
|
||||||
|
|
||||||
def mark_start(key):
|
def mark_start(name, interval=0.1, color=0, indent=0):
|
||||||
|
global time_infos, show_time_cost
|
||||||
|
if not show_time_cost:
|
||||||
|
return
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
global time_mark
|
if time_infos.get(name, None) is None:
|
||||||
time_mark[key] = time.time()
|
time_infos[name] = TimeInfo(name, interval, color, indent)
|
||||||
return
|
time_infos[name].acc_time -= time.time()
|
||||||
|
|
||||||
|
|
||||||
def mark_end(key, print_min_cost=0.0):
|
def mark_end(name):
|
||||||
|
global time_infos, show_time_cost
|
||||||
|
if not show_time_cost:
|
||||||
|
return
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
global time_mark
|
time_infos[name].acc_time += time.time()
|
||||||
cost_time = (time.time() - time_mark[key]) * 1000
|
if time_infos[name].check():
|
||||||
if cost_time > print_min_cost:
|
time_infos[name].pretty_print()
|
||||||
print(f"cost {key}:", cost_time)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_time(show=False, min_cost_ms=0.0):
|
def calculate_time(show=False, min_cost_ms=0.0):
|
||||||
|
|||||||
@@ -66,9 +66,9 @@ class BenchBatch:
|
|||||||
p_idx = prefix_req_idx[i // fork_num].item()
|
p_idx = prefix_req_idx[i // fork_num].item()
|
||||||
n_idx = self.req_pool_indices[i].item()
|
n_idx = self.req_pool_indices[i].item()
|
||||||
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
|
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
|
||||||
req_to_token[
|
req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
|
||||||
n_idx, prefix_len : prefix_len + extend_len
|
self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
|
||||||
] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
|
)
|
||||||
|
|
||||||
def update_decode(self, predict_ids, batch_size):
|
def update_decode(self, predict_ids, batch_size):
|
||||||
assert predict_ids.shape[0] == batch_size
|
assert predict_ids.shape[0] == batch_size
|
||||||
@@ -81,9 +81,9 @@ class BenchBatch:
|
|||||||
self.out_cache_cont_start,
|
self.out_cache_cont_start,
|
||||||
self.out_cache_cont_end,
|
self.out_cache_cont_end,
|
||||||
) = self.token_to_kv_pool.alloc_contiguous(batch_size)
|
) = self.token_to_kv_pool.alloc_contiguous(batch_size)
|
||||||
self.req_to_token_pool.req_to_token[
|
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
|
||||||
self.req_pool_indices, self.seq_lens
|
self.out_cache_loc
|
||||||
] = self.out_cache_loc
|
)
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user