Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)

Co-authored-by: Kan Wu <wukanustc@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-06-29 23:16:19 -07:00
committed by GitHub
parent c5131f7a2f
commit 22352d47a9
24 changed files with 626 additions and 160 deletions

View File

@@ -418,14 +418,16 @@ class Scheduler(
self.last_decode_stats_tic = time.perf_counter()
self.last_prefill_stats_tic = time.perf_counter()
self.return_health_check_ct = 0
self.num_retracted_reqs: int = 0
self.num_paused_reqs: int = 0
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.sessions: Dict[str, Session] = {}
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None
# Init session info
self.sessions: Dict[str, Session] = {}
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
if self.chunked_prefill_size <= 0: # -1 means disable
@@ -473,26 +475,12 @@ class Scheduler(
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
self.parent_process = psutil.Process().parent()
# Init memory saver, profiler and metric stats
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
# Init profiler
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
# Init metrics stats
self.init_profier()
self.init_metrics()
self.init_kv_events(server_args.kv_events_config)
@@ -526,6 +514,7 @@ class Scheduler(
]
)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
@@ -624,6 +613,21 @@ class Scheduler(
)
)
def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
def init_metrics(self):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
@@ -2107,6 +2111,18 @@ class Scheduler(
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = {
"weight": round(
self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
),
"kvcache": round(
self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
),
"cuda_graph": round(
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
),
"token_capacity": int(self.max_total_num_tokens),
}
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
self.cum_spec_accept_length / self.cum_spec_accept_count