Clean up server_args.py (#7037)

This commit is contained in:
Lianmin Zheng
2025-06-10 05:34:29 -07:00
committed by GitHub
parent 019851d099
commit 6406408a70
7 changed files with 394 additions and 331 deletions

View File

@@ -17,12 +17,14 @@ from __future__ import annotations
import bisect
import inspect
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import tqdm
from torch.profiler import ProfilerActivity, profile
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -40,11 +42,14 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_device_memory_capacity,
rank0_log,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -207,6 +212,9 @@ class CudaGraphRunner:
model_runner.server_args.enable_two_batch_overlap
)
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
self.pp_size = model_runner.server_args.pp_size
@@ -339,44 +347,67 @@ class CudaGraphRunner:
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
def capture(self):
def capture(self) -> None:
profile_context = empty_context()
if self.enable_profile_cuda_graph:
profile_context = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
)
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory(
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
)
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(list(reversed(self.capture_bs)))
if get_tensor_model_parallel_rank() == 0
else reversed(self.capture_bs)
)
for bs in capture_range:
if get_tensor_model_parallel_rank() == 0:
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
)
with profile_context as prof:
self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(list(reversed(self.capture_bs)))
if get_tensor_model_parallel_rank() == 0
else reversed(self.capture_bs)
)
for i, bs in enumerate(capture_range):
if get_tensor_model_parallel_rank() == 0:
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
)
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
num_tokens=bs * self.num_tokens_per_bs,
tp_group=self.model_runner.tp_group,
) as forward:
(
graph,
output_buffers,
) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.output_buffers[bs] = output_buffers
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
num_tokens=bs * self.num_tokens_per_bs,
tp_group=self.model_runner.tp_group,
) as forward:
(
graph,
output_buffers,
) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.output_buffers[bs] = output_buffers
# Save gemlite cache after each capture
save_gemlite_cache()
# Save gemlite cache after each capture
save_gemlite_cache()
if self.enable_profile_cuda_graph:
log_message = (
"Sorted by CUDA Time:\n"
+ prof.key_averages(group_by_input_shape=True).table(
sort_by="cuda_time_total", row_limit=10
)
+ "\n\nSorted by CPU Time:\n"
+ prof.key_averages(group_by_input_shape=True).table(
sort_by="cpu_time_total", row_limit=10
)
)
logger.info(log_message)
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
@@ -443,7 +474,7 @@ class CudaGraphRunner:
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
seq_lens_sum=seq_lens.sum().item(),
encoder_lens=encoder_lens,
return_logprob=False,
positions=positions,