Simplify prometheus metrics (#1981)

Co-authored-by: Mohit Reddy <mohitreddy1996@users.noreply.github.com>
This commit is contained in:
Lianmin Zheng
2024-11-10 04:39:32 -08:00
committed by GitHub
parent ed53ac84b4
commit 1929c06762
11 changed files with 483 additions and 632 deletions

View File

@@ -62,8 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -106,6 +105,7 @@ class Scheduler:
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = server_args.enable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
# Init inter-process communication
context = zmq.Context(2)
@@ -224,8 +224,7 @@ class Scheduler:
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time() # time of last stats for every iter
self.last_log_tic = time.time() # time of last log for print decode log
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
# Init chunked prefill
@@ -294,15 +293,16 @@ class Scheduler:
],
with_stack=True,
)
# Init metrics stats
self.stats = Stats()
self.metrics_collector = PrometheusMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
max_model_len=self.max_total_num_tokens,
)
self.stats = SchedulerStats()
if self.enable_metrics:
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
@@ -350,11 +350,6 @@ class Scheduler:
else:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# log stats
if self.is_generation and self.server_args.enable_metrics:
stats = self.get_stats(batch)
self.log_stats(stats)
self.last_stats_tic = time.time()
self.last_batch = batch
@@ -493,7 +488,6 @@ class Scheduler:
self.max_req_len - len(req.origin_input_ids) - 1,
)
req.created_time = time.time()
self.waiting_queue.append(req)
def handle_embedding_request(
@@ -518,25 +512,68 @@ class Scheduler:
self.waiting_queue.append(req)
def print_decode_stats(self):
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
if self.enable_metrics:
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
self.stats.cache_hit_rate = tree_cache_hit_rate
self.metrics_collector.log_stats(self.stats)
def log_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
gen_throughput = self.num_generated_tokens / (
time.time() - self.last_decode_stats_tic
)
self.num_generated_tokens = 0
self.last_log_tic = time.time()
# set system stats
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.last_decode_stats_tic = time.time()
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.metrics_collector.log_stats(self.stats)
def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -612,7 +649,6 @@ class Scheduler:
prefix_computed = self.policy.calc_priority(self.waiting_queue)
# Prefill policy
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder(
self.tree_cache,
self.running_batch,
@@ -620,7 +656,7 @@ class Scheduler:
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
num_mixed_running,
running_bs if self.is_mixed_chunk else 0,
)
has_inflight = self.being_chunked_req is not None
@@ -677,47 +713,7 @@ class Scheduler:
# Print stats
if self.tp_rank == 0:
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
# set system stats
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
if num_mixed_running > 0:
logger.info(
f"Prefill batch"
f"(mixed #running-req: {num_mixed_running}). "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
else:
logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
# Create a new batch
new_batch = ScheduleBatch.init_new(
@@ -789,7 +785,6 @@ class Scheduler:
if self.is_generation:
model_worker_batch = batch.get_model_worker_batch()
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
batch.mark_reqs_started()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
@@ -810,94 +805,6 @@ class Scheduler:
ret = embeddings, model_worker_batch.bid
return ret
def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill
now = time.time()
# system stats
# Scheduler State
new_seq: int = 0
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
num_waiting_req = len(self.waiting_queue)
# Cache State
cache_hit_rate: float = 0.0
token_usage: float = 0.0
# set stats from prefill
if self.stats is not None:
# new_seq=self.stats.new_seq
cache_hit_rate = self.stats.cache_hit_rate
token_usage = self.stats.token_usage
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
# Request stats
# Decode
gen_throughput: float = 0.0
# Latency
time_e2e_requests: List[float] = []
time_waiting_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []
# _, next_token_ids, _ = result
if batch is not None:
num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(
num_generation_tokens_iter / (now - self.last_stats_tic), 2
)
for i, req in enumerate(batch.reqs):
# NOTE: Batch forward mode is extend befor start decode,
if batch.forward_mode.is_extend():
num_prompt_tokens_iter = len(batch.input_ids) + sum(
batch.prefix_lens
)
time_to_first_tokens_iter.append(now - req.started_time)
else:
time_per_output_tokens_iter.append(now - self.last_stats_tic)
if req.finished():
time_e2e_requests.append(now - req.created_time)
time_waiting_requests.append(req.queued_time - req.created_time)
num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append(
req.finished_reason.to_json()
if req.finished_reason is not None
else None
)
return Stats(
new_seq=new_seq,
num_running_req=num_running_req,
num_waiting_req=num_waiting_req,
cache_hit_rate=cache_hit_rate,
token_usage=token_usage,
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
gen_throughput=gen_throughput,
time_e2e_requests=time_e2e_requests,
time_waiting_requests=time_waiting_requests,
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
finished_reason_requests=finished_reason_requests,
context_len=self.model_config.context_len,
max_total_num_tokens=self.max_total_num_tokens,
max_prefill_tokens=self.max_prefill_tokens,
max_running_requests=self.max_running_requests,
)
def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats)
def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
@@ -1035,7 +942,7 @@ class Scheduler:
self.tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.print_decode_stats()
self.log_decode_stats()
def add_logprob_return_values(
self,