Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)

Co-authored-by: Kan Wu <wukanustc@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-06-29 23:16:19 -07:00
committed by GitHub
parent c5131f7a2f
commit 22352d47a9
24 changed files with 626 additions and 160 deletions

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
parser.add_argument("--log-requests", action="store_true")
parser.add_argument("--log-requests-level", type=int, default=2)
parser.add_argument("--log-requests-level", type=int, default=3)
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)

View File

@@ -516,9 +516,6 @@ class EmbeddingReqInput:
# For cross-encoder requests
is_cross_encoder_request: bool = False
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
def normalize_batch_and_arguments(self):
# at least one of text, input_ids, or image should be provided
if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +569,9 @@ class EmbeddingReqInput:
self.rid = uuid.uuid4().hex
return self.rid
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
def __getitem__(self, i):
if self.is_cross_encoder_request:
return EmbeddingReqInput(

View File

@@ -2,12 +2,15 @@
Multi-modality utils
"""
import hashlib
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
@@ -678,3 +681,52 @@ def get_multimodal_data_bounds(
# Convert valid pairs to tensor
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor
def data_hash(data) -> int:
    """Hash a bytes-like object to a 64-bit unsigned int.

    Uses the first 8 bytes of the SHA-256 digest, interpreted big-endian.
    """
    digest = hashlib.sha256(data).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=False)
def tensor_hash(tensor_list) -> int:
    """Hash a tensor or a (possibly nested) list of tensors to a 64-bit int.

    CUDA tensors are hashed on-device via ``gpu_tensor_hash``; CPU tensors
    are hashed over their raw bytes via ``data_hash``.
    """
    tensor = tensor_list
    if isinstance(tensor_list, list):
        # Flatten nested lists and concatenate everything into one 1-D tensor.
        tensor_list = flatten_nested_list(tensor_list)
        tensor_list = [
            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
        ]
        tensor = torch.concat(tensor_list)
    if tensor.is_cuda:
        return gpu_tensor_hash(tensor)
    tensor = tensor.detach().contiguous()

    if tensor.dtype == torch.bfloat16:
        # memoryview() doesn't support PyTorch's BFloat16 dtype
        tensor = tensor.float()

    assert isinstance(tensor, torch.Tensor)
    # The CUDA case already returned above, so the tensor is on CPU here;
    # the original re-checked `tensor.is_cuda` and copied via `.cpu()`, which
    # was unreachable dead code and has been removed.
    mv = memoryview(tensor.numpy())
    return data_hash(mv.tobytes())
def hash_feature(f):
    """Hash a multimodal feature: a tensor, an ndarray, a list, or raw data."""
    if isinstance(f, torch.Tensor):
        return tensor_hash([f])
    if isinstance(f, np.ndarray):
        # tobytes() needs a C-contiguous buffer.
        return data_hash(np.ascontiguousarray(f).tobytes())
    if isinstance(f, list):
        # A list of tensors is hashed tensor-wise; anything else is flattened
        # into a tuple and hashed as data.
        if isinstance(f[0], torch.Tensor):
            return tensor_hash(f)
        return data_hash(tuple(flatten_nested_list(f)))
    return data_hash(f)

View File

@@ -3,7 +3,6 @@ import importlib
import inspect
import logging
import pkgutil
from functools import lru_cache
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.server_args import ServerArgs

View File

@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
@@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
@@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"max_micro_batch_size",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_acc",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
@@ -180,7 +178,9 @@ class Modality(Enum):
@dataclasses.dataclass
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or others.
One MultimodalDataItem contains all inputs for one modality.
For example, if there are 3 image inputs and 1 audio input, there will be 2 MultimodalDataItems:
one for the images and one for the audio.
We put the common fields first and the model-specific fields last.
"""
@@ -232,53 +232,7 @@ class MultimodalDataItem:
"""
Set the pad value after first hashing the data
"""
def data_hash(data) -> int:
hash_bytes = hashlib.sha256(data).digest()[:8]
return int.from_bytes(hash_bytes, byteorder="big", signed=False)
def tensor_hash(tensor_list) -> int:
"""
hash a tensor or a tensor list
"""
tensor = tensor_list
if isinstance(tensor_list, list):
tensor_list = flatten_nested_list(tensor_list)
tensor_list = [
x.flatten() if isinstance(x, torch.Tensor) else x
for x in tensor_list
]
tensor = torch.concat(tensor_list)
if tensor.is_cuda:
return gpu_tensor_hash(tensor)
tensor = tensor.detach().contiguous()
if tensor.dtype == torch.bfloat16:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()
assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor
mv = memoryview(tensor_cpu.numpy())
return data_hash(mv.tobytes())
def hash_feature(f):
if isinstance(f, list):
if isinstance(f[0], torch.Tensor):
return tensor_hash(f)
return data_hash(tuple(flatten_nested_list(f)))
elif isinstance(f, np.ndarray):
arr = np.ascontiguousarray(f)
arr_bytes = arr.tobytes()
return data_hash(arr_bytes)
elif isinstance(f, torch.Tensor):
return tensor_hash([f])
return data_hash(f)
from sglang.srt.managers.mm_utils import hash_feature
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)

View File

@@ -418,14 +418,16 @@ class Scheduler(
self.last_decode_stats_tic = time.perf_counter()
self.last_prefill_stats_tic = time.perf_counter()
self.return_health_check_ct = 0
self.num_retracted_reqs: int = 0
self.num_paused_reqs: int = 0
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.sessions: Dict[str, Session] = {}
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None
# Init session info
self.sessions: Dict[str, Session] = {}
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
if self.chunked_prefill_size <= 0: # -1 means disable
@@ -473,26 +475,12 @@ class Scheduler(
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
self.parent_process = psutil.Process().parent()
# Init memory saver, profiler and metric stats
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
# Init profiler
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
# Init metrics stats
self.init_profier()
self.init_metrics()
self.init_kv_events(server_args.kv_events_config)
@@ -526,6 +514,7 @@ class Scheduler(
]
)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
@@ -624,6 +613,21 @@ class Scheduler(
)
)
def init_profier(self):
    """Reset all profiler bookkeeping (torch and RPD) to the idle state."""
    # Profiler handles and output location.
    self.torch_profiler = None
    self.torch_profiler_output_dir: Optional[str] = None
    self.rpd_profiler = None
    self.profiler_activities: Optional[List[str]] = None
    self.profile_id: Optional[str] = None
    # Step counters that control when profiling starts/stops.
    self.profiler_target_forward_ct: Optional[int] = None
    self.profiler_target_prefill_ct: Optional[int] = None
    self.profiler_target_decode_ct: Optional[int] = None
    self.profiler_prefill_ct: Optional[int] = None
    self.profiler_decode_ct: Optional[int] = None
    # Stage-based profiling state.
    self.profile_by_stage: bool = False
    self.profile_steps: Optional[int] = None
    self.profile_in_progress: bool = False
def init_metrics(self):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
@@ -2107,6 +2111,18 @@ class Scheduler(
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = {
"weight": round(
self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
),
"kvcache": round(
self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
),
"cuda_graph": round(
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
),
"token_capacity": int(self.max_total_num_tokens),
}
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
self.cum_spec_accept_length / self.cum_spec_accept_count

View File

@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
stream_interval = (
req.sampling_params.stream_interval or self.stream_interval
)
should_output = len(req.output_ids) % stream_interval == 0
should_output = (
len(req.output_ids) % stream_interval == 1
if not self.model_config.is_multimodal_gen
and stream_interval > 1
else len(req.output_ids) % stream_interval == 0
)
else:
should_output = (
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
and not self.model_config.is_multimodal_gen
if not self.model_config.is_multimodal_gen
else False
)
if should_output:

View File

@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.multimodal_processor import (
get_dummy_processor,
get_mm_processor,
import_processors,
)
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -187,6 +183,8 @@ class TokenizerManager:
if server_args.preferred_sampling_params
else None
)
self.crash_dump_folder = server_args.crash_dump_folder
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -251,10 +249,11 @@ class TokenizerManager:
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.crash_dump_request_list: deque[Tuple] = deque()
self.log_request_metadata = self.get_log_request_metadata()
self.asyncio_tasks = set()
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
self.asyncio_tasks = set()
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
@@ -266,14 +265,14 @@ class TokenizerManager:
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start the bootstrap server on the prefill tokenizer manager
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
@@ -324,7 +323,6 @@ class TokenizerManager:
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -484,7 +482,7 @@ class TokenizerManager:
token_type_ids = encoded.get("token_type_ids", [None])[0]
if self.mm_processor and obj.contains_mm_input():
image_inputs = await self.mm_processor.process_mm_data_async(
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
input_text=input_text or input_ids,
request_obj=obj,
@@ -547,6 +545,14 @@ class TokenizerManager:
"Please set `--enable-custom-logits-processor` to enable this feature."
)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
) -> None:
if any(id >= vocab_size for id in input_ids):
raise ValueError(
f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
)
def _create_tokenized_object(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1096,12 +1102,36 @@ class TokenizerManager:
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
out_skip_names = set(["text", "output_ids", "embedding"])
elif self.log_requests_level == 1:
max_length = 2048
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
@@ -1118,6 +1148,8 @@ class TokenizerManager:
self.dump_requests_folder = obj.dump_requests_folder
if obj.dump_requests_threshold is not None:
self.dump_requests_threshold = obj.dump_requests_threshold
if obj.crash_dump_folder is not None:
self.crash_dump_folder = obj.crash_dump_folder
logging.info(f"Config logging: {obj=}")
self.log_request_metadata = self.get_log_request_metadata()
@@ -1166,6 +1198,52 @@ class TokenizerManager:
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
)
def dump_requests_before_crash(self):
    """Dump recently-finished and unfinished requests to a pickle file.

    Called from signal handlers and top-level exception paths; guarded by
    `crash_dump_performed` so the dump runs at most once per process.
    """
    if self.crash_dump_performed:
        logger.info(
            "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
        )
        return
    logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
    self.crash_dump_performed = True
    if not self.crash_dump_folder:
        return

    data_to_dump = []
    if self.crash_dump_request_list:
        data_to_dump.extend(self.crash_dump_request_list)

    # Add unfinished requests from rid_to_state
    unfinished_requests = []
    for rid, state in self.rid_to_state.items():
        if not state.finished:
            unfinished_requests.append(
                (state.obj, {}, state.created_time, time.time())
            )
    if unfinished_requests:
        data_to_dump.extend(unfinished_requests)

    if not data_to_dump:
        return

    # os.path.join raises TypeError on a None component, so use a string
    # fallback when HOSTNAME is not set (e.g. outside containers).
    filename = os.path.join(
        self.crash_dump_folder,
        os.getenv("HOSTNAME", "unknown"),
        f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
    )
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Include server_args in the dump
    data_to_dump_with_server_args = {
        "server_args": self.server_args,
        "requests": data_to_dump,
    }
    with open(filename, "wb") as f:
        pickle.dump(data_to_dump_with_server_args, f)

    logger.error(
        f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
    )
async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(5)
@@ -1175,11 +1253,12 @@ class TokenizerManager:
remain_num_req = len(self.rid_to_state)
if self.health_check_failed:
# if health check failed, exit immediately
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
remain_num_req,
)
self.dump_requests_before_crash()
break
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
@@ -1196,6 +1275,7 @@ class TokenizerManager:
if remain_num_req > 0:
await asyncio.sleep(5)
else:
self.dump_requests_before_crash()
break
kill_process_tree(os.getpid(), include_parent=True)
@@ -1273,16 +1353,7 @@ class TokenizerManager:
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchMultimodalOut):
if isinstance(recv_obj.outputs[i], str):
out_dict = {
"text": recv_obj.outputs[i],
"meta_info": meta_info,
}
else:
out_dict = {
"outputs": json.dumps(recv_obj.outputs[i]),
"meta_info": meta_info,
}
raise NotImplementedError("BatchMultimodalOut not implemented")
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
@@ -1306,6 +1377,8 @@ class TokenizerManager:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
self.dump_requests(state, out_dict)
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
self.record_request_for_crash_dump(state, out_dict)
def convert_logprob_style(
self,
@@ -1317,6 +1390,9 @@ class TokenizerManager:
recv_obj: BatchStrOut,
recv_obj_index: int,
):
if recv_obj.input_token_logprobs_val is None:
return
if len(recv_obj.input_token_logprobs_val) > 0:
state.input_token_logprobs_val.extend(
recv_obj.input_token_logprobs_val[recv_obj_index]
@@ -1436,7 +1512,10 @@ class TokenizerManager:
else 0
)
if state.first_token_time == 0.0:
if (
state.first_token_time == 0.0
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
state.first_token_time = state.last_time = time.time()
state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token(
@@ -1484,14 +1563,31 @@ class TokenizerManager:
to_dump = self.dump_request_list
self.dump_request_list = []
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": to_dump,
}
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump, f)
pickle.dump(to_dump_with_server_args, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
    """Buffer a finished request for a potential crash dump.

    Maintains a rolling 5-minute window: entries whose finish time is 300
    seconds or more in the past are evicted from the left of the deque.
    """
    now = time.time()
    buffer = self.crash_dump_request_list
    buffer.append((state.obj, out_dict, state.created_time, now))
    # Index 3 of each entry holds the finish timestamp.
    while buffer and now - buffer[0][3] >= 300:
        buffer.popleft()
def _handle_abort_req(self, recv_obj):
self.rid_to_state.pop(recv_obj.rid, None)
@@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func):
except Exception:
traceback = get_exception_traceback()
logger.error(f"TokenizerManager hit an exception: {traceback}")
if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
func.__self__.dump_requests_before_crash()
kill_process_tree(os.getpid(), include_parent=True)
sys.exit(1)
@@ -1632,6 +1730,7 @@ class SignalHandler:
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
self.tokenizer_manager.dump_requests_before_crash()
kill_process_tree(os.getpid())