Simplify prometheus metrics (#1981)
Co-authored-by: Mohit Reddy <mohitreddy1996@users.noreply.github.com>
This commit is contained in:
@@ -31,7 +31,6 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -255,16 +254,6 @@ class Req:
|
|||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
self.mrope_position_delta = [] # use mutable object
|
self.mrope_position_delta = [] # use mutable object
|
||||||
|
|
||||||
# Lifetime traces
|
|
||||||
# time when request is created and added to waitlist
|
|
||||||
self.created_time = None
|
|
||||||
# time when request is added to prefill batch
|
|
||||||
self.queued_time = None
|
|
||||||
# time when request is being processed
|
|
||||||
self.started_time = None
|
|
||||||
# time when request is finished
|
|
||||||
self.finished_time = None
|
|
||||||
|
|
||||||
# whether request reached finished condition
|
# whether request reached finished condition
|
||||||
def finished(self) -> bool:
|
def finished(self) -> bool:
|
||||||
return self.finished_reason is not None
|
return self.finished_reason is not None
|
||||||
@@ -1038,10 +1027,6 @@ class ScheduleBatch:
|
|||||||
f"#req={(len(self.reqs))})"
|
f"#req={(len(self.reqs))})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def mark_reqs_started(self):
|
|
||||||
for req in self.reqs:
|
|
||||||
req.started_time = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ModelWorkerBatch:
|
class ModelWorkerBatch:
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ limitations under the License.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
@@ -307,7 +306,6 @@ class PrefillAdder:
|
|||||||
):
|
):
|
||||||
# Non-chunked prefill
|
# Non-chunked prefill
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
req.queued_time = time.time()
|
|
||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
self._prefill_one_req(
|
self._prefill_one_req(
|
||||||
prefix_len,
|
prefix_len,
|
||||||
@@ -326,7 +324,6 @@ class PrefillAdder:
|
|||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
req.queued_time = time.time()
|
|
||||||
self.new_inflight_req = req
|
self.new_inflight_req = req
|
||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||||
|
|||||||
@@ -62,8 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
|
|||||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
from sglang.srt.metrics.metrics_types import Stats
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
@@ -106,6 +105,7 @@ class Scheduler:
|
|||||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||||
self.enable_overlap = server_args.enable_overlap_schedule
|
self.enable_overlap = server_args.enable_overlap_schedule
|
||||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||||
|
self.enable_metrics = server_args.enable_metrics
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
@@ -224,8 +224,7 @@ class Scheduler:
|
|||||||
self.forward_ct = 0
|
self.forward_ct = 0
|
||||||
self.forward_ct_decode = 0
|
self.forward_ct_decode = 0
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
self.last_stats_tic = time.time() # time of last stats for every iter
|
self.last_decode_stats_tic = time.time()
|
||||||
self.last_log_tic = time.time() # time of last log for print decode log
|
|
||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
|
|
||||||
# Init chunked prefill
|
# Init chunked prefill
|
||||||
@@ -294,14 +293,15 @@ class Scheduler:
|
|||||||
],
|
],
|
||||||
with_stack=True,
|
with_stack=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init metrics stats
|
# Init metrics stats
|
||||||
self.stats = Stats()
|
self.stats = SchedulerStats()
|
||||||
self.metrics_collector = PrometheusMetricsCollector(
|
if self.enable_metrics:
|
||||||
|
self.metrics_collector = SchedulerMetricsCollector(
|
||||||
labels={
|
labels={
|
||||||
"model_name": self.server_args.served_model_name,
|
"model_name": self.server_args.served_model_name,
|
||||||
# TODO: Add lora name/path in the future,
|
# TODO: Add lora name/path in the future,
|
||||||
},
|
},
|
||||||
max_model_len=self.max_total_num_tokens,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def watchdog_thread(self):
|
def watchdog_thread(self):
|
||||||
@@ -350,11 +350,6 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
# log stats
|
|
||||||
if self.is_generation and self.server_args.enable_metrics:
|
|
||||||
stats = self.get_stats(batch)
|
|
||||||
self.log_stats(stats)
|
|
||||||
self.last_stats_tic = time.time()
|
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -493,7 +488,6 @@ class Scheduler:
|
|||||||
self.max_req_len - len(req.origin_input_ids) - 1,
|
self.max_req_len - len(req.origin_input_ids) - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
req.created_time = time.time()
|
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def handle_embedding_request(
|
def handle_embedding_request(
|
||||||
@@ -518,25 +512,68 @@ class Scheduler:
|
|||||||
|
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def print_decode_stats(self):
|
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
    """Log a one-line summary of a prefill batch and publish scheduler stats.

    Args:
        adder: Object exposing ``log_input_tokens`` and ``log_hit_tokens``
            for this batch (reported as ``#new-token`` / ``#cached-token``).
        can_run_list: Requests selected for this batch; only its length is
            used (reported as ``#new-seq``).
        running_bs: Number of running requests, reported as ``#running-req``.
        has_inflight: Truthy when a chunked request is still in flight; it
            is added to the waiting-queue length as 0 or 1.
    """
    if isinstance(self.tree_cache, RadixCache):
        # The accumulators are kept in units of 1e9 tokens; the hit/total
        # ratio is unaffected by that scaling.
        self.tree_cache_metrics["total"] += (
            adder.log_input_tokens + adder.log_hit_tokens
        ) / 10**9
        self.tree_cache_metrics["hit"] += adder.log_hit_tokens / 10**9

        total = self.tree_cache_metrics["total"]
        # Guard the ratio: if nothing has been accumulated yet the division
        # would raise ZeroDivisionError.
        tree_cache_hit_rate = (
            self.tree_cache_metrics["hit"] / total if total > 0 else 0.0
        )
    else:
        tree_cache_hit_rate = 0.0

    num_used = self.max_total_num_tokens - (
        self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
    )
    token_usage = num_used / self.max_total_num_tokens
    num_queue_reqs = len(self.waiting_queue) + has_inflight

    logger.info(
        f"Prefill batch. "
        f"#new-seq: {len(can_run_list)}, "
        f"#new-token: {adder.log_input_tokens}, "
        f"#cached-token: {adder.log_hit_tokens}, "
        f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
        f"token usage: {token_usage:.2f}, "
        f"#running-req: {running_bs}, "
        f"#queue-req: {num_queue_reqs}"
    )

    if self.enable_metrics:
        self.stats.num_running_reqs = running_bs
        self.stats.num_used_tokens = num_used
        # Published unrounded, the same way log_decode_stats publishes it, so
        # the gauge does not alternate between rounded and exact values.
        self.stats.token_usage = token_usage
        self.stats.num_queue_reqs = num_queue_reqs
        self.stats.cache_hit_rate = tree_cache_hit_rate
        self.metrics_collector.log_stats(self.stats)
|
||||||
|
|
||||||
|
def log_decode_stats(self):
    """Log a one-line summary of decoding progress and publish scheduler stats.

    Generation throughput is the number of tokens generated since the
    previous call divided by the wall-clock time since that call; both the
    token counter and the timestamp are reset here.
    """
    num_used = self.max_total_num_tokens - (
        self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
    )

    # Sample the clock once so the measured interval and the new reference
    # point agree; a zero interval (coarse clock) reports 0.0 instead of
    # raising ZeroDivisionError.
    now = time.time()
    elapsed = now - self.last_decode_stats_tic
    gen_throughput = self.num_generated_tokens / elapsed if elapsed > 0 else 0.0
    self.num_generated_tokens = 0
    self.last_decode_stats_tic = now

    num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
    token_usage = num_used / self.max_total_num_tokens
    num_queue_reqs = len(self.waiting_queue)

    logger.info(
        f"Decode batch. "
        f"#running-req: {num_running_reqs}, "
        f"#token: {num_used}, "
        f"token usage: {token_usage:.2f}, "
        f"gen throughput (token/s): {gen_throughput:.2f}, "
        f"#queue-req: {num_queue_reqs}"
    )

    if self.enable_metrics:
        self.stats.num_running_reqs = num_running_reqs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = token_usage
        self.stats.gen_throughput = gen_throughput
        self.stats.num_queue_reqs = num_queue_reqs
        self.metrics_collector.log_stats(self.stats)
|
||||||
|
|
||||||
def check_memory(self):
|
def check_memory(self):
|
||||||
available_size = (
|
available_size = (
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||||
@@ -612,7 +649,6 @@ class Scheduler:
|
|||||||
prefix_computed = self.policy.calc_priority(self.waiting_queue)
|
prefix_computed = self.policy.calc_priority(self.waiting_queue)
|
||||||
|
|
||||||
# Prefill policy
|
# Prefill policy
|
||||||
num_mixed_running = running_bs if self.is_mixed_chunk else 0
|
|
||||||
adder = PrefillAdder(
|
adder = PrefillAdder(
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.running_batch,
|
self.running_batch,
|
||||||
@@ -620,7 +656,7 @@ class Scheduler:
|
|||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
self.chunked_prefill_size,
|
self.chunked_prefill_size,
|
||||||
num_mixed_running,
|
running_bs if self.is_mixed_chunk else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
has_inflight = self.being_chunked_req is not None
|
has_inflight = self.being_chunked_req is not None
|
||||||
@@ -677,47 +713,7 @@ class Scheduler:
|
|||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
if isinstance(self.tree_cache, RadixCache):
|
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
|
||||||
self.tree_cache_metrics["total"] += (
|
|
||||||
adder.log_input_tokens + adder.log_hit_tokens
|
|
||||||
) / 10**9
|
|
||||||
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
|
|
||||||
tree_cache_hit_rate = (
|
|
||||||
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tree_cache_hit_rate = 0.0
|
|
||||||
|
|
||||||
num_used = self.max_total_num_tokens - (
|
|
||||||
self.token_to_kv_pool.available_size()
|
|
||||||
+ self.tree_cache.evictable_size()
|
|
||||||
)
|
|
||||||
# set system stats
|
|
||||||
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
|
|
||||||
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
|
|
||||||
|
|
||||||
if num_mixed_running > 0:
|
|
||||||
logger.info(
|
|
||||||
f"Prefill batch"
|
|
||||||
f"(mixed #running-req: {num_mixed_running}). "
|
|
||||||
f"#new-seq: {len(can_run_list)}, "
|
|
||||||
f"#new-token: {adder.log_input_tokens}, "
|
|
||||||
f"#cached-token: {adder.log_hit_tokens}, "
|
|
||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
|
||||||
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f"Prefill batch. "
|
|
||||||
f"#new-seq: {len(can_run_list)}, "
|
|
||||||
f"#new-token: {adder.log_input_tokens}, "
|
|
||||||
f"#cached-token: {adder.log_hit_tokens}, "
|
|
||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
|
||||||
f"#running-req: {running_bs}, "
|
|
||||||
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a new batch
|
# Create a new batch
|
||||||
new_batch = ScheduleBatch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
@@ -789,7 +785,6 @@ class Scheduler:
|
|||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||||
batch.mark_reqs_started()
|
|
||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
@@ -810,94 +805,6 @@ class Scheduler:
|
|||||||
ret = embeddings, model_worker_batch.bid
|
ret = embeddings, model_worker_batch.bid
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_stats(self, batch: ScheduleBatch):
|
|
||||||
# TODO: get stats for chunked prefill
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
# system stats
|
|
||||||
# Scheduler State
|
|
||||||
new_seq: int = 0
|
|
||||||
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
|
|
||||||
num_waiting_req = len(self.waiting_queue)
|
|
||||||
# Cache State
|
|
||||||
cache_hit_rate: float = 0.0
|
|
||||||
token_usage: float = 0.0
|
|
||||||
|
|
||||||
# set stats from prefill
|
|
||||||
if self.stats is not None:
|
|
||||||
# new_seq=self.stats.new_seq
|
|
||||||
cache_hit_rate = self.stats.cache_hit_rate
|
|
||||||
token_usage = self.stats.token_usage
|
|
||||||
# Iteration stats
|
|
||||||
num_prompt_tokens_iter = 0
|
|
||||||
num_generation_tokens_iter = 0
|
|
||||||
time_to_first_tokens_iter: List[float] = []
|
|
||||||
time_per_output_tokens_iter: List[float] = []
|
|
||||||
|
|
||||||
# Request stats
|
|
||||||
# Decode
|
|
||||||
gen_throughput: float = 0.0
|
|
||||||
# Latency
|
|
||||||
time_e2e_requests: List[float] = []
|
|
||||||
time_waiting_requests: List[float] = []
|
|
||||||
# Metadata
|
|
||||||
num_prompt_tokens_requests: List[int] = []
|
|
||||||
num_generation_tokens_requests: List[int] = []
|
|
||||||
finished_reason_requests: List[str] = []
|
|
||||||
|
|
||||||
# _, next_token_ids, _ = result
|
|
||||||
if batch is not None:
|
|
||||||
num_generation_tokens_iter = len(batch.output_ids)
|
|
||||||
gen_throughput = round(
|
|
||||||
num_generation_tokens_iter / (now - self.last_stats_tic), 2
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, req in enumerate(batch.reqs):
|
|
||||||
# NOTE: Batch forward mode is extend befor start decode,
|
|
||||||
if batch.forward_mode.is_extend():
|
|
||||||
num_prompt_tokens_iter = len(batch.input_ids) + sum(
|
|
||||||
batch.prefix_lens
|
|
||||||
)
|
|
||||||
time_to_first_tokens_iter.append(now - req.started_time)
|
|
||||||
else:
|
|
||||||
time_per_output_tokens_iter.append(now - self.last_stats_tic)
|
|
||||||
|
|
||||||
if req.finished():
|
|
||||||
time_e2e_requests.append(now - req.created_time)
|
|
||||||
time_waiting_requests.append(req.queued_time - req.created_time)
|
|
||||||
num_prompt_tokens_requests.append(len(req.origin_input_ids))
|
|
||||||
num_generation_tokens_requests.append(len(req.output_ids))
|
|
||||||
finished_reason_requests.append(
|
|
||||||
req.finished_reason.to_json()
|
|
||||||
if req.finished_reason is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return Stats(
|
|
||||||
new_seq=new_seq,
|
|
||||||
num_running_req=num_running_req,
|
|
||||||
num_waiting_req=num_waiting_req,
|
|
||||||
cache_hit_rate=cache_hit_rate,
|
|
||||||
token_usage=token_usage,
|
|
||||||
num_prompt_tokens_iter=num_prompt_tokens_iter,
|
|
||||||
num_generation_tokens_iter=num_generation_tokens_iter,
|
|
||||||
time_to_first_tokens_iter=time_to_first_tokens_iter,
|
|
||||||
time_per_output_tokens_iter=time_per_output_tokens_iter,
|
|
||||||
gen_throughput=gen_throughput,
|
|
||||||
time_e2e_requests=time_e2e_requests,
|
|
||||||
time_waiting_requests=time_waiting_requests,
|
|
||||||
num_prompt_tokens_requests=num_prompt_tokens_requests,
|
|
||||||
num_generation_tokens_requests=num_generation_tokens_requests,
|
|
||||||
finished_reason_requests=finished_reason_requests,
|
|
||||||
context_len=self.model_config.context_len,
|
|
||||||
max_total_num_tokens=self.max_total_num_tokens,
|
|
||||||
max_prefill_tokens=self.max_prefill_tokens,
|
|
||||||
max_running_requests=self.max_running_requests,
|
|
||||||
)
|
|
||||||
|
|
||||||
def log_stats(self, stats: Stats):
|
|
||||||
self.metrics_collector.log_stats(stats)
|
|
||||||
|
|
||||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
self.process_batch_result_decode(batch, result)
|
self.process_batch_result_decode(batch, result)
|
||||||
@@ -1035,7 +942,7 @@ class Scheduler:
|
|||||||
self.tp_rank == 0
|
self.tp_rank == 0
|
||||||
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
|
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
|
||||||
):
|
):
|
||||||
self.print_decode_stats()
|
self.log_decode_stats()
|
||||||
|
|
||||||
def add_logprob_return_values(
|
def add_logprob_return_values(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
@@ -52,6 +53,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightReqInput,
|
UpdateWeightReqInput,
|
||||||
UpdateWeightReqOutput,
|
UpdateWeightReqOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import get_zmq_socket, kill_child_process
|
from sglang.srt.utils import get_zmq_socket, kill_child_process
|
||||||
@@ -69,6 +71,10 @@ class ReqState:
|
|||||||
finished: bool
|
finished: bool
|
||||||
event: asyncio.Event
|
event: asyncio.Event
|
||||||
|
|
||||||
|
# For metrics
|
||||||
|
created_time: float
|
||||||
|
first_token_time: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class TokenizerManager:
|
class TokenizerManager:
|
||||||
"""TokenizerManager is a process that tokenizes the text."""
|
"""TokenizerManager is a process that tokenizes the text."""
|
||||||
@@ -80,6 +86,7 @@ class TokenizerManager:
|
|||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
self.enable_metrics = server_args.enable_metrics
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
@@ -142,11 +149,22 @@ class TokenizerManager:
|
|||||||
# Others
|
# Others
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
|
|
||||||
|
# Metrics
|
||||||
|
if self.enable_metrics:
|
||||||
|
self.metrics_collector = TokenizerMetricsCollector(
|
||||||
|
labels={
|
||||||
|
"model_name": self.server_args.served_model_name,
|
||||||
|
# TODO: Add lora name/path in the future,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
|
created_time = time.time()
|
||||||
|
|
||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
|
|
||||||
@@ -164,10 +182,12 @@ class TokenizerManager:
|
|||||||
if is_single:
|
if is_single:
|
||||||
tokenized_obj = await self._tokenize_one_request(obj)
|
tokenized_obj = await self._tokenize_one_request(obj)
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
async for response in self._wait_one_response(obj, request):
|
async for response in self._wait_one_response(obj, request, created_time):
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
async for response in self._handle_batch_request(obj, request):
|
async for response in self._handle_batch_request(
|
||||||
|
obj, request, created_time
|
||||||
|
):
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
async def _tokenize_one_request(
|
async def _tokenize_one_request(
|
||||||
@@ -231,10 +251,11 @@ class TokenizerManager:
|
|||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
|
created_time: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""Wait for the response of one request."""
|
"""Wait for the response of one request."""
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event, created_time=created_time)
|
||||||
self.rid_to_state[obj.rid] = state
|
self.rid_to_state[obj.rid] = state
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -272,6 +293,7 @@ class TokenizerManager:
|
|||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
|
created_time: Optional[float] = None,
|
||||||
):
|
):
|
||||||
batch_size = obj.batch_size
|
batch_size = obj.batch_size
|
||||||
|
|
||||||
@@ -283,7 +305,9 @@ class TokenizerManager:
|
|||||||
tmp_obj = obj[i]
|
tmp_obj = obj[i]
|
||||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
generators.append(self._wait_one_response(tmp_obj, request))
|
generators.append(
|
||||||
|
self._wait_one_response(tmp_obj, request, created_time)
|
||||||
|
)
|
||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||||
@@ -303,7 +327,9 @@ class TokenizerManager:
|
|||||||
tokenized_obj.sampling_params.max_new_tokens = 0
|
tokenized_obj.sampling_params.max_new_tokens = 0
|
||||||
tokenized_obj.stream = False
|
tokenized_obj.stream = False
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
await self._wait_one_response(tmp_obj, request).__anext__()
|
await self._wait_one_response(
|
||||||
|
tmp_obj, request, created_time
|
||||||
|
).__anext__()
|
||||||
|
|
||||||
# Expand requests, assign new rids for them, and send them
|
# Expand requests, assign new rids for them, and send them
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
@@ -312,7 +338,9 @@ class TokenizerManager:
|
|||||||
tokenized_obj = copy.copy(tokenized_objs[i])
|
tokenized_obj = copy.copy(tokenized_objs[i])
|
||||||
tokenized_obj.rid = tmp_obj.regenerate_rid()
|
tokenized_obj.rid = tmp_obj.regenerate_rid()
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
generators.append(self._wait_one_response(tmp_obj, request))
|
generators.append(
|
||||||
|
self._wait_one_response(tmp_obj, request, created_time)
|
||||||
|
)
|
||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
|
|
||||||
# Wait for all requests
|
# Wait for all requests
|
||||||
@@ -524,6 +552,34 @@ class TokenizerManager:
|
|||||||
state.finished = recv_obj.finished_reason[i] is not None
|
state.finished = recv_obj.finished_reason[i] is not None
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
|
if self.enable_metrics:
|
||||||
|
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
|
||||||
|
|
||||||
|
if state.first_token_time is None:
|
||||||
|
state.first_token_time = time.time()
|
||||||
|
self.metrics_collector.observe_time_to_first_token(
|
||||||
|
state.first_token_time - state.created_time
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if completion_tokens >= 2:
|
||||||
|
self.metrics_collector.observe_time_per_output_token(
|
||||||
|
(time.time() - state.first_token_time)
|
||||||
|
/ (completion_tokens - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if state.finished:
|
||||||
|
self.metrics_collector.inc_prompt_tokens(
|
||||||
|
recv_obj.meta_info[i]["prompt_tokens"]
|
||||||
|
)
|
||||||
|
self.metrics_collector.inc_generation_tokens(completion_tokens)
|
||||||
|
self.metrics_collector.observe_e2e_request_latency(
|
||||||
|
time.time() - state.created_time
|
||||||
|
)
|
||||||
|
if completion_tokens >= 1:
|
||||||
|
self.metrics_collector.observe_time_per_output_token(
|
||||||
|
(time.time() - state.created_time) / completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
def convert_logprob_style(
|
def convert_logprob_style(
|
||||||
self,
|
self,
|
||||||
ret: dict,
|
ret: dict,
|
||||||
|
|||||||
211
python/sglang/srt/metrics/collector.py
Normal file
211
python/sglang/srt/metrics/collector.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""Utilities for Prometheus Metrics Collection."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SchedulerStats:
    """Snapshot of scheduler state that is exported as Prometheus gauges.

    Every field defaults to zero, so a fresh instance can be created up front
    and filled in field by field.
    """

    # The number of running requests.
    num_running_reqs: int = 0
    # The number of used tokens.
    num_used_tokens: int = 0
    # The token usage.
    token_usage: float = 0.0
    # The generate throughput (token/s).
    gen_throughput: float = 0.0
    # The number of requests in the waiting queue.
    num_queue_reqs: int = 0
    # The cache hit rate.
    cache_hit_rate: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMetricsCollector:
    """Publish scheduler statistics as Prometheus gauges.

    One gauge is created per ``SchedulerStats`` field; ``log_stats`` copies a
    stats snapshot into those gauges using the labels given at construction.
    """

    def __init__(self, labels: Dict[str, str]) -> None:
        """Create the gauges.

        Args:
            labels: Mapping of label name to label value. The keys are
                declared as the label names of every gauge, and the whole
                mapping is applied on every update.
        """
        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
        from prometheus_client import Gauge

        self.labels = labels

        def make_gauge(name: str, documentation: str, multiprocess_mode: str):
            # All gauges share the same label names; only the metric name,
            # its description and the multiprocess aggregation mode differ.
            return Gauge(
                name=name,
                documentation=documentation,
                labelnames=labels.keys(),
                multiprocess_mode=multiprocess_mode,
            )

        self.num_running_reqs = make_gauge(
            "sglang:num_running_reqs",
            "The number of running requests",
            "sum",
        )
        self.num_used_tokens = make_gauge(
            "sglang:num_used_tokens",
            "The number of used tokens",
            "sum",
        )
        self.token_usage = make_gauge(
            "sglang:token_usage",
            "The token usage",
            "mostrecent",
        )
        self.gen_throughput = make_gauge(
            "sglang:gen_throughput",
            "The generate throughput (token/s)",
            "sum",
        )
        self.num_queue_reqs = make_gauge(
            "sglang:num_queue_reqs",
            "The number of requests in the waiting queue",
            "sum",
        )
        self.cache_hit_rate = make_gauge(
            "sglang:cache_hit_rate",
            "The cache hit rate",
            "mostrecent",
        )

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        """Set ``gauge`` to ``data`` under this collector's labels."""
        gauge.labels(**self.labels).set(data)

    def log_stats(self, stats: "SchedulerStats") -> None:
        """Copy every field of ``stats`` into the gauge of the same name."""
        self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
        self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
        self._log_gauge(self.token_usage, stats.token_usage)
        self._log_gauge(self.gen_throughput, stats.gen_throughput)
        self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
        self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerMetricsCollector:
    """Token counters and latency histograms, all recorded under one fixed label set."""

    def __init__(self, labels: Dict[str, str]) -> None:
        """Create the counters and histograms, using the keys of ``labels`` as label names."""
        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
        from prometheus_client import Counter, Histogram

        self.labels = labels
        label_names = labels.keys()

        # Bucket upper bounds, in seconds, for the three latency histograms.
        ttft_buckets = [
            0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
            0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0,
        ]
        per_token_buckets = [
            0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075,
            0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5,
        ]
        e2e_buckets = [
            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0,
            10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0,
        ]

        self.prompt_tokens_total = Counter(
            name="sglang:prompt_tokens_total",
            documentation="Number of prefill tokens processed.",
            labelnames=label_names,
        )
        self.generation_tokens_total = Counter(
            name="sglang:generation_tokens_total",
            documentation="Number of generation tokens processed.",
            labelnames=label_names,
        )
        self.histogram_time_to_first_token = Histogram(
            name="sglang:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=label_names,
            buckets=ttft_buckets,
        )
        self.histogram_time_per_output_token = Histogram(
            name="sglang:time_per_output_token_seconds",
            documentation="Histogram of time per output token in seconds.",
            labelnames=label_names,
            buckets=per_token_buckets,
        )
        self.histogram_e2e_request_latency = Histogram(
            name="sglang:e2e_request_latency_seconds",
            documentation="Histogram of End-to-end request latency in seconds",
            labelnames=label_names,
            buckets=e2e_buckets,
        )

    def _log_histogram(self, histogram, data: Union[int, float]) -> None:
        """Observe ``data`` on ``histogram`` under this collector's labels."""
        labeled = histogram.labels(**self.labels)
        labeled.observe(data)

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        """Increase ``counter`` by ``data`` under this collector's labels."""
        labeled = counter.labels(**self.labels)
        labeled.inc(data)

    def inc_prompt_tokens(self, value: int):
        """Add ``value`` to the prompt-token counter."""
        self._log_counter(self.prompt_tokens_total, value)

    def inc_generation_tokens(self, value: int):
        """Add ``value`` to the generation-token counter."""
        self._log_counter(self.generation_tokens_total, value)

    def observe_time_to_first_token(self, value: Union[float, int]):
        """Record ``value`` in the time-to-first-token histogram."""
        self._log_histogram(self.histogram_time_to_first_token, value)

    def observe_time_per_output_token(self, value: Union[float, int]):
        """Record ``value`` in the time-per-output-token histogram."""
        self._log_histogram(self.histogram_time_per_output_token, value)

    def observe_e2e_request_latency(self, value: Union[float, int]):
        """Record ``value`` in the end-to-end request latency histogram."""
        self._log_histogram(self.histogram_e2e_request_latency, value)
|
||||||
108
python/sglang/srt/metrics/func_timer.py
Normal file
108
python/sglang/srt/metrics/func_timer.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Records the latency of some functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
|
enable_metrics = False
|
||||||
|
|
||||||
|
|
||||||
|
def enable_func_timer():
    """Turn on function-latency recording and create the shared histogram.

    Sets the module-level ``enable_metrics`` flag and, on the first call only,
    builds the ``sglang:func_latency_seconds`` histogram stored in
    ``FUNC_LATENCY``. Calling this again is a no-op apart from re-asserting the
    flag, so the histogram is never registered twice.
    """
    global enable_metrics, FUNC_LATENCY

    if FUNC_LATENCY is not None:
        # The histogram already exists. Constructing a second one under the
        # same name would be rejected by prometheus_client's registry as a
        # duplicated time series, so just make sure recording is switched on.
        enable_metrics = True
        return

    # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
    from prometheus_client import Histogram

    FUNC_LATENCY = Histogram(
        "sglang:func_latency_seconds",
        "Function latency in seconds",
        # captures latency in range [50ms - ~50s]
        buckets=exponential_buckets(start=0.05, width=1.5, length=18),
        labelnames=["name"],
    )
    # Flip the flag only once the histogram exists, so the timing wrappers can
    # never observe "enabled" together with ``FUNC_LATENCY is None``.
    enable_metrics = True


# Shared latency histogram; stays ``None`` until ``enable_func_timer()`` runs.
FUNC_LATENCY = None
|
||||||
|
|
||||||
|
|
||||||
|
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return ``length`` bucket bounds: ``start`` times successive powers of ``width``."""
    return [start * (width**exponent) for exponent in range(length)]
|
||||||
|
|
||||||
|
|
||||||
|
def time_func_latency(
    func: Optional[Callable[..., Any]] = None, name: Optional[str] = None
) -> Callable[..., Any]:
    """
    A decorator to observe the latency of a function's execution. Supports both sync and async functions.

    Usable bare (``@time_func_latency``) or with arguments
    (``@time_func_latency(name="...")``). The elapsed wall time is observed on
    the module-level ``FUNC_LATENCY`` histogram under the label ``name``, which
    defaults to the decorated function's ``__name__``. The observation is made
    whether the call returns or raises. While ``enable_metrics`` is false the
    wrapped function is called directly and nothing is recorded.

    NOTE: We use our own implementation of a timer decorator since prometheus_client does not support async
    context manager yet.

    Overhead: for an async function the wrapper adds one extra coroutine and
    one extra ``await`` per call, which under heavy load may show up as
    slightly longer wall time (additional scheduling delay).
    """

    def measure(func: Callable[..., Any]) -> Callable[..., Any]:
        # Resolve the label per decorated function. Rebinding the enclosing
        # ``name`` here would leak the first function's name to every other
        # function decorated with the same ``time_func_latency()`` result.
        metric_name = name or func.__name__

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            if not enable_metrics:
                return await func(*args, **kwargs)

            metric = FUNC_LATENCY
            start = time.monotonic()
            try:
                return await func(*args, **kwargs)
            finally:
                metric.labels(name=metric_name).observe(time.monotonic() - start)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            if not enable_metrics:
                return func(*args, **kwargs)

            metric = FUNC_LATENCY
            start = time.monotonic()
            try:
                return func(*args, **kwargs)
            finally:
                metric.labels(name=metric_name).observe(time.monotonic() - start)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    # Bare usage passes the function itself; parameterized usage returns the
    # decorator to be applied afterwards.
    if func is not None:
        return measure(func)
    return measure
|
||||||
@@ -1,388 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""Utilities for Prometheus Metrics Collection."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Counter as CollectionsCounter
|
|
||||||
from typing import Dict, List, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from prometheus_client import Counter, Gauge, Histogram
|
|
||||||
|
|
||||||
from sglang.srt.metrics.metrics_types import Stats
|
|
||||||
|
|
||||||
|
|
||||||
class Metrics:
    """
    SGLang Metrics

    Holds one attribute per Prometheus gauge, counter, or histogram; all of
    them are created in ``__init__`` with the same label names.
    """

    def __init__(self, labelnames: List[str], max_model_len):
        """Create every metric; ``max_model_len`` caps the token-count buckets."""

        def make_gauge(metric_name, doc, mode):
            return Gauge(
                name=metric_name,
                documentation=doc,
                labelnames=labelnames,
                multiprocess_mode=mode,
            )

        def make_counter(metric_name, doc, names=labelnames):
            return Counter(name=metric_name, documentation=doc, labelnames=names)

        def make_histogram(metric_name, doc, bounds):
            return Histogram(
                name=metric_name,
                documentation=doc,
                labelnames=labelnames,
                buckets=bounds,
            )

        # Bucket upper bounds in seconds. Each histogram gets its own list copy.
        first_token_bounds = (
            0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
            0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0,
        )
        per_token_bounds = (
            0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075,
            0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5,
        )
        request_bounds = (
            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0,
            10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0,
        )

        # Configuration Stats (static across processes)
        self.max_total_num_tokens = make_gauge(
            "sglang:max_total_num_tokens", "Maximum total number of tokens", "min"
        )
        self.max_prefill_tokens = make_gauge(
            "sglang:max_prefill_tokens", "Maximum prefill tokens", "min"
        )
        self.max_running_requests = make_gauge(
            "sglang:max_running_requests", "Maximum running requests", "min"
        )
        self.context_len = make_gauge("sglang:context_len", "Context length", "min")

        # Decode Stats
        self.num_running_sys = make_gauge(
            "sglang:num_requests_running",
            "Number of requests currently running on GPU",
            "sum",
        )
        self.num_waiting_sys = make_gauge(
            "sglang:num_requests_waiting",
            "Number of requests waiting to be processed.",
            "sum",
        )
        self.gen_throughput = make_gauge(
            "sglang:gen_throughput", "Gen token throughput (token/s)", "sum"
        )
        self.token_usage = make_gauge(
            "sglang:token_usage", "Total token usage", "sum"
        )

        # System Stats
        # NOTE: no GPU KV-cache usage gauge is registered here.
        self.new_seq = make_gauge("sglang:new_seq", "Number of new sequences", "sum")
        self.new_token = make_gauge("sglang:new_token", "Number of new token", "sum")
        # Prefix caching block hit rate
        self.cached_token = make_gauge(
            "sglang:cached_token", "Number of cached token", "sum"
        )
        self.cache_hit_rate = make_gauge(
            "sglang:cache_hit_rate", "Cache hit rate", "sum"
        )
        self.queue_req = make_gauge(
            "sglang:queue_req", "Number of queued requests", "sum"
        )

        # Iteration stats
        self.counter_prompt_tokens = make_counter(
            "sglang:prompt_tokens_total", "Number of prefill tokens processed."
        )
        self.counter_generation_tokens = make_counter(
            "sglang:generation_tokens_total",
            "Number of generation tokens processed.",
        )
        self.histogram_time_to_first_token = make_histogram(
            "sglang:time_to_first_token_seconds",
            "Histogram of time to first token in seconds.",
            list(first_token_bounds),
        )
        self.histogram_time_per_output_token = make_histogram(
            "sglang:time_per_output_token_seconds",
            "Histogram of time per output token in seconds.",
            list(per_token_bounds),
        )

        # Request Stats
        # Metadata
        self.num_prompt_tokens_requests = make_histogram(
            "sglang:request_prompt_tokens",
            "Number of prefill tokens processed",
            build_1_2_5_buckets(max_model_len),
        )
        self.num_generation_tokens_requests = make_histogram(
            "sglang:request_generation_tokens",
            "Number of generation tokens processed.",
            build_1_2_5_buckets(max_model_len),
        )
        self.finished_reason_requests = make_counter(
            "sglang:request_success_total",
            "Count of successfully processed requests.",
            labelnames + ["finished_reason"],
        )
        self.histogram_time_e2e_requests = make_histogram(
            "sglang:e2e_request_latency_seconds",
            "Histogram of End-to-end request latency in seconds",
            list(request_bounds),
        )
        self.histogram_time_waiting_requests = make_histogram(
            "sglang:waiting_request_latency_seconds",
            "Histogram of request waiting time in seconds",
            list(request_bounds),
        )
        self.histogram_time_decode_requests = make_histogram(
            "sglang:decode_request_latency_seconds",
            "Histogram of request decoding time in seconds",
            list(request_bounds),
        )
|
|
||||||
|
|
||||||
|
|
||||||
class MetricsCollector(ABC):
    """
    SGLang Metrics Collector

    Abstract interface: subclasses must provide ``log_stats``.
    """

    @abstractmethod
    def log_stats(self, stats: Stats) -> None:
        """Record ``stats``; this base implementation does nothing."""
|
|
||||||
|
|
||||||
|
|
||||||
class PrometheusMetricsCollector(MetricsCollector):
    """
    SGLang Metrics Collector

    Writes a ``Stats`` snapshot into the metrics held by a ``Metrics`` instance,
    always under the label values given at construction.
    """

    def __init__(self, labels: Dict[str, str], max_model_len: int) -> None:
        """Keep ``labels`` and build ``Metrics`` with the label keys as label names."""
        self.labels = labels
        label_names = list(labels.keys())
        self.metrics = Metrics(labelnames=label_names, max_model_len=max_model_len)

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        """Set ``gauge`` to ``data`` under this collector's labels."""
        labeled = gauge.labels(**self.labels)
        labeled.set(data)

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        """Increase ``counter`` by ``data`` under this collector's labels."""
        labeled = counter.labels(**self.labels)
        labeled.inc(data)

    def _log_counter_labels(
        self, counter, data: CollectionsCounter, label_key: str
    ) -> None:
        """Increase ``counter`` once per entry of ``data``, adding ``label_key`` as an extra label."""
        for label_value, amount in data.items():
            combined = dict(self.labels)
            combined[label_key] = label_value
            counter.labels(**combined).inc(amount)

    def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
        """Observe each element of ``data`` on ``histogram`` under this collector's labels."""
        # The labelled child is looked up per element, so an empty ``data``
        # performs no lookup at all.
        for sample in data:
            histogram.labels(**self.labels).observe(sample)

    def log_stats(self, stats: Stats) -> None:
        """Push every field of ``stats`` into its gauge, counter, or histogram."""
        m = self.metrics

        # Configuration gauges
        self._log_gauge(m.max_total_num_tokens, stats.max_total_num_tokens)
        self._log_gauge(m.max_prefill_tokens, stats.max_prefill_tokens)
        self._log_gauge(m.max_running_requests, stats.max_running_requests)
        self._log_gauge(m.context_len, stats.context_len)

        # Per-request token counts
        self._log_histogram(
            m.num_prompt_tokens_requests, stats.num_prompt_tokens_requests
        )
        self._log_histogram(
            m.num_generation_tokens_requests, stats.num_generation_tokens_requests
        )

        # Per-iteration counters and latency histograms
        self._log_counter(m.counter_prompt_tokens, stats.num_prompt_tokens_iter)
        self._log_counter(
            m.counter_generation_tokens, stats.num_generation_tokens_iter
        )
        self._log_histogram(
            m.histogram_time_to_first_token, stats.time_to_first_tokens_iter
        )
        self._log_histogram(
            m.histogram_time_per_output_token, stats.time_per_output_tokens_iter
        )

        # Decode gauges
        self._log_gauge(m.num_running_sys, stats.num_running_req)
        self._log_gauge(m.num_waiting_sys, stats.num_waiting_req)
        self._log_gauge(m.gen_throughput, stats.gen_throughput)
        self._log_gauge(m.token_usage, stats.token_usage)

        # Per-request timing histograms
        self._log_histogram(m.histogram_time_e2e_requests, stats.time_e2e_requests)
        self._log_histogram(
            m.histogram_time_waiting_requests, stats.time_waiting_requests
        )
        self._log_histogram(
            m.histogram_time_decode_requests, stats.time_decode_requests
        )

        # System gauges
        self._log_gauge(m.new_seq, stats.new_seq)
        self._log_gauge(m.new_token, stats.new_token)
        self._log_gauge(m.cached_token, stats.cached_token)
        self._log_gauge(m.cache_hit_rate, stats.cache_hit_rate)
        self._log_gauge(m.queue_req, stats.queue_req)
|
|
||||||
|
|
||||||
|
|
||||||
def build_1_2_5_buckets(max_value: int) -> List[int]:
    """
    Return the ascending values 1, 2, 5, 10, 20, 50, ... that do not exceed
    ``max_value`` (each is a mantissa of 1, 2 or 5 times a power of ten).

    Example:
    >>> build_1_2_5_buckets(100)
    [1, 2, 5, 10, 20, 50, 100]
    """
    result: List[int] = []
    scale = 1
    while True:
        for mantissa in (1, 2, 5):
            candidate = mantissa * scale
            # Stop at the first value past the limit; everything after it is larger.
            if candidate > max_value:
                return result
            result.append(candidate)
        scale *= 10
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""Metrics Types"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Stats:
    """Plain container for one snapshot of metric values; every field has a default."""

    # --- config ---
    max_total_num_tokens: int = 0
    max_prefill_tokens: int = 0
    max_running_requests: int = 0
    context_len: int = 0

    # --- request stats (list-valued) ---
    num_prompt_tokens_requests: List[int] = field(default_factory=list)
    num_generation_tokens_requests: List[int] = field(default_factory=list)
    finished_reason_requests: List[str] = field(default_factory=list)

    # --- decode stats ---
    num_running_req: int = 0
    num_waiting_req: int = 0
    gen_throughput: float = 0.0
    waiting_queue: int = 0
    time_e2e_requests: List[float] = field(default_factory=list)
    time_waiting_requests: List[float] = field(default_factory=list)
    time_decode_requests: List[float] = field(default_factory=list)

    # --- system stats ---
    token_usage: float = 0.0
    new_seq: int = 0
    new_token: int = 0
    cached_token: int = 0
    cache_hit_rate: float = 0.0
    running_req: int = 0
    queue_req: int = 0

    # --- iteration stats (should have _iter suffix) ---
    num_prompt_tokens_iter: int = 0
    num_generation_tokens_iter: int = 0
    time_to_first_tokens_iter: List[float] = field(default_factory=list)
    time_per_output_tokens_iter: List[float] = field(default_factory=list)
|
|
||||||
@@ -56,6 +56,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
|
||||||
from sglang.srt.openai_api.adapter import (
|
from sglang.srt.openai_api.adapter import (
|
||||||
load_chat_template_for_openai_api,
|
load_chat_template_for_openai_api,
|
||||||
v1_batches,
|
v1_batches,
|
||||||
@@ -196,6 +197,7 @@ async def get_memory_pool_size():
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights")
|
@app.post("/update_weights")
|
||||||
|
@time_func_latency
|
||||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||||
"""Update the weights inplace without re-launching the server."""
|
"""Update the weights inplace without re-launching the server."""
|
||||||
success, message = await tokenizer_manager.update_weights(obj, request)
|
success, message = await tokenizer_manager.update_weights(obj, request)
|
||||||
@@ -212,7 +214,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
@time_func_latency
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
"""Handle a generate request."""
|
"""Handle a generate request."""
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
@@ -245,10 +247,12 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||||
app.post("/generate")(generate_request)
|
app.post("/generate")(generate_request)
|
||||||
app.put("/generate")(generate_request)
|
app.put("/generate")(generate_request)
|
||||||
|
|
||||||
|
|
||||||
|
@time_func_latency
|
||||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||||
"""Handle an embedding request."""
|
"""Handle an embedding request."""
|
||||||
try:
|
try:
|
||||||
@@ -264,6 +268,7 @@ app.post("/encode")(encode_request)
|
|||||||
app.put("/encode")(encode_request)
|
app.put("/encode")(encode_request)
|
||||||
|
|
||||||
|
|
||||||
|
@time_func_latency
|
||||||
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||||
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||||
try:
|
try:
|
||||||
@@ -283,16 +288,19 @@ app.put("/classify")(classify_request)
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
|
@time_func_latency
|
||||||
async def openai_v1_completions(raw_request: Request):
|
async def openai_v1_completions(raw_request: Request):
|
||||||
return await v1_completions(tokenizer_manager, raw_request)
|
return await v1_completions(tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
|
@time_func_latency
|
||||||
async def openai_v1_chat_completions(raw_request: Request):
|
async def openai_v1_chat_completions(raw_request: Request):
|
||||||
return await v1_chat_completions(tokenizer_manager, raw_request)
|
return await v1_chat_completions(tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/embeddings", response_class=ORJSONResponse)
|
@app.post("/v1/embeddings", response_class=ORJSONResponse)
|
||||||
|
@time_func_latency
|
||||||
async def openai_v1_embeddings(raw_request: Request):
|
async def openai_v1_embeddings(raw_request: Request):
|
||||||
response = await v1_embeddings(tokenizer_manager, raw_request)
|
response = await v1_embeddings(tokenizer_manager, raw_request)
|
||||||
return response
|
return response
|
||||||
@@ -455,6 +463,7 @@ def launch_server(
|
|||||||
# add prometheus middleware
|
# add prometheus middleware
|
||||||
if server_args.enable_metrics:
|
if server_args.enable_metrics:
|
||||||
add_prometheus_middleware(app)
|
add_prometheus_middleware(app)
|
||||||
|
enable_func_timer()
|
||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
t = threading.Thread(
|
t = threading.Thread(
|
||||||
@@ -492,6 +501,10 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||||
|
|
||||||
|
# Set prometheus env vars
|
||||||
|
if server_args.enable_metrics:
|
||||||
|
set_prometheus_multiproc_dir()
|
||||||
|
|
||||||
# Set ulimit
|
# Set ulimit
|
||||||
set_ulimit()
|
set_ulimit()
|
||||||
|
|
||||||
@@ -510,10 +523,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set prometheus env vars
|
|
||||||
if server_args.enable_metrics:
|
|
||||||
set_prometheus_multiproc_dir()
|
|
||||||
|
|
||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -781,6 +781,7 @@ def set_prometheus_multiproc_dir():
|
|||||||
|
|
||||||
|
|
||||||
def add_prometheus_middleware(app):
|
def add_prometheus_middleware(app):
|
||||||
|
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
|
||||||
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
|
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
|
||||||
|
|
||||||
registry = CollectorRegistry()
|
registry = CollectorRegistry()
|
||||||
|
|||||||
@@ -22,23 +22,41 @@ class TestEnableMetrics(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Make a request to generate some metrics
|
# Make some requests to generate some metrics
|
||||||
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
|
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{DEFAULT_URL_FOR_TEST}/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
},
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for _ in response.iter_lines(decode_unicode=False):
|
||||||
|
pass
|
||||||
|
|
||||||
# Get metrics
|
# Get metrics
|
||||||
metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics")
|
metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics")
|
||||||
self.assertEqual(metrics_response.status_code, 200)
|
self.assertEqual(metrics_response.status_code, 200)
|
||||||
metrics_content = metrics_response.text
|
metrics_content = metrics_response.text
|
||||||
|
|
||||||
print(f"{metrics_content=}")
|
print(f"metrics_content=\n{metrics_content}")
|
||||||
|
|
||||||
# Verify essential metrics are present
|
# Verify essential metrics are present
|
||||||
essential_metrics = [
|
essential_metrics = [
|
||||||
|
"sglang:num_running_reqs",
|
||||||
|
"sglang:token_usage",
|
||||||
|
"sglang:gen_throughput",
|
||||||
|
"sglang:cache_hit_rate",
|
||||||
|
"sglang:func_latency_seconds",
|
||||||
"sglang:prompt_tokens_total",
|
"sglang:prompt_tokens_total",
|
||||||
"sglang:generation_tokens_total",
|
"sglang:generation_tokens_total",
|
||||||
"sglang:max_total_num_tokens",
|
|
||||||
"sglang:context_len",
|
|
||||||
"sglang:time_to_first_token_seconds",
|
"sglang:time_to_first_token_seconds",
|
||||||
"sglang:time_per_output_token_seconds",
|
"sglang:time_per_output_token_seconds",
|
||||||
"sglang:e2e_request_latency_seconds",
|
"sglang:e2e_request_latency_seconds",
|
||||||
@@ -50,6 +68,7 @@ class TestEnableMetrics(unittest.TestCase):
|
|||||||
# Verify model name label is present and correct
|
# Verify model name label is present and correct
|
||||||
expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
self.assertIn(f'model_name="{expected_model_name}"', metrics_content)
|
self.assertIn(f'model_name="{expected_model_name}"', metrics_content)
|
||||||
|
|
||||||
# Verify metrics have values (not empty)
|
# Verify metrics have values (not empty)
|
||||||
self.assertIn("_sum{", metrics_content)
|
self.assertIn("_sum{", metrics_content)
|
||||||
self.assertIn("_count{", metrics_content)
|
self.assertIn("_count{", metrics_content)
|
||||||
|
|||||||
Reference in New Issue
Block a user