Clean up server_args.py (#7037)
This commit is contained in:
@@ -17,12 +17,14 @@ from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.profiler import ProfilerActivity, profile
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
@@ -40,11 +42,14 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
get_device_memory_capacity,
|
||||
rank0_log,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
@@ -207,6 +212,9 @@ class CudaGraphRunner:
|
||||
model_runner.server_args.enable_two_batch_overlap
|
||||
)
|
||||
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||
self.enable_profile_cuda_graph = (
|
||||
model_runner.server_args.enable_profile_cuda_graph
|
||||
)
|
||||
self.tp_size = model_runner.server_args.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
self.pp_size = model_runner.server_args.pp_size
|
||||
@@ -339,44 +347,67 @@ class CudaGraphRunner:
|
||||
|
||||
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
|
||||
|
||||
def capture(self):
|
||||
def capture(self) -> None:
|
||||
profile_context = empty_context()
|
||||
if self.enable_profile_cuda_graph:
|
||||
profile_context = profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
)
|
||||
|
||||
with graph_capture() as graph_capture_context:
|
||||
self.stream = graph_capture_context.stream
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
|
||||
)
|
||||
# Reverse the order to enable better memory sharing across cuda graphs.
|
||||
capture_range = (
|
||||
tqdm.tqdm(list(reversed(self.capture_bs)))
|
||||
if get_tensor_model_parallel_rank() == 0
|
||||
else reversed(self.capture_bs)
|
||||
)
|
||||
for bs in capture_range:
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device,
|
||||
self.model_runner.gpu_id,
|
||||
empty_cache=False,
|
||||
)
|
||||
capture_range.set_description(
|
||||
f"Capturing batches ({avail_mem=:.2f} GB)"
|
||||
)
|
||||
with profile_context as prof:
|
||||
self.stream = graph_capture_context.stream
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device,
|
||||
self.model_runner.gpu_id,
|
||||
empty_cache=False,
|
||||
)
|
||||
# Reverse the order to enable better memory sharing across cuda graphs.
|
||||
capture_range = (
|
||||
tqdm.tqdm(list(reversed(self.capture_bs)))
|
||||
if get_tensor_model_parallel_rank() == 0
|
||||
else reversed(self.capture_bs)
|
||||
)
|
||||
for i, bs in enumerate(capture_range):
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device,
|
||||
self.model_runner.gpu_id,
|
||||
empty_cache=False,
|
||||
)
|
||||
capture_range.set_description(
|
||||
f"Capturing batches ({avail_mem=:.2f} GB)"
|
||||
)
|
||||
|
||||
with patch_model(
|
||||
self.model_runner.model,
|
||||
bs in self.compile_bs,
|
||||
num_tokens=bs * self.num_tokens_per_bs,
|
||||
tp_group=self.model_runner.tp_group,
|
||||
) as forward:
|
||||
(
|
||||
graph,
|
||||
output_buffers,
|
||||
) = self.capture_one_batch_size(bs, forward)
|
||||
self.graphs[bs] = graph
|
||||
self.output_buffers[bs] = output_buffers
|
||||
with patch_model(
|
||||
self.model_runner.model,
|
||||
bs in self.compile_bs,
|
||||
num_tokens=bs * self.num_tokens_per_bs,
|
||||
tp_group=self.model_runner.tp_group,
|
||||
) as forward:
|
||||
(
|
||||
graph,
|
||||
output_buffers,
|
||||
) = self.capture_one_batch_size(bs, forward)
|
||||
self.graphs[bs] = graph
|
||||
self.output_buffers[bs] = output_buffers
|
||||
|
||||
# Save gemlite cache after each capture
|
||||
save_gemlite_cache()
|
||||
# Save gemlite cache after each capture
|
||||
save_gemlite_cache()
|
||||
|
||||
if self.enable_profile_cuda_graph:
|
||||
log_message = (
|
||||
"Sorted by CUDA Time:\n"
|
||||
+ prof.key_averages(group_by_input_shape=True).table(
|
||||
sort_by="cuda_time_total", row_limit=10
|
||||
)
|
||||
+ "\n\nSorted by CPU Time:\n"
|
||||
+ prof.key_averages(group_by_input_shape=True).table(
|
||||
sort_by="cpu_time_total", row_limit=10
|
||||
)
|
||||
)
|
||||
logger.info(log_message)
|
||||
|
||||
def capture_one_batch_size(self, bs: int, forward: Callable):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
@@ -443,7 +474,7 @@ class CudaGraphRunner:
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
out_cache_loc=out_cache_loc,
|
||||
seq_lens_sum=seq_lens.sum(),
|
||||
seq_lens_sum=seq_lens.sum().item(),
|
||||
encoder_lens=encoder_lens,
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
|
||||
Reference in New Issue
Block a user