Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)
Co-authored-by: Kan Wu <wukanustc@gmail.com>
This commit is contained in:
@@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
||||
parser.add_argument("--log-requests", action="store_true")
|
||||
parser.add_argument("--log-requests-level", type=int, default=2)
|
||||
parser.add_argument("--log-requests-level", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
|
||||
)
|
||||
|
||||
@@ -516,9 +516,6 @@ class EmbeddingReqInput:
|
||||
# For cross-encoder requests
|
||||
is_cross_encoder_request: bool = False
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
# at least one of text, input_ids, or image should be provided
|
||||
if self.text is None and self.input_ids is None and self.image_data is None:
|
||||
@@ -572,6 +569,9 @@ class EmbeddingReqInput:
|
||||
self.rid = uuid.uuid4().hex
|
||||
return self.rid
|
||||
|
||||
def contains_mm_input(self) -> bool:
    """Tell whether this request carries image or audio input.

    Both ``image_data`` and ``audio_data`` are judged by ``has_valid_data``;
    the audio check is only evaluated when the image check is falsy.
    """
    image_check = has_valid_data(self.image_data)
    return image_check if image_check else has_valid_data(self.audio_data)
|
||||
|
||||
def __getitem__(self, i):
|
||||
if self.is_cross_encoder_request:
|
||||
return EmbeddingReqInput(
|
||||
|
||||
@@ -2,12 +2,15 @@
|
||||
Multi-modality utils
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
@@ -678,3 +681,52 @@ def get_multimodal_data_bounds(
|
||||
# Convert valid pairs to tensor
|
||||
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
|
||||
return valid_pairs_tensor
|
||||
|
||||
|
||||
def data_hash(data) -> int:
    """Return an unsigned 64-bit integer hash of ``data``.

    ``data`` is fed to SHA-256; the first 8 bytes of the digest are
    interpreted as a big-endian unsigned integer.
    """
    digest = hashlib.sha256(data).digest()
    leading_bytes = digest[:8]
    return int.from_bytes(leading_bytes, byteorder="big", signed=False)
|
||||
|
||||
|
||||
def tensor_hash(tensor_list) -> int:
    """
    Hash a tensor or a list of tensors.

    A list input is passed through ``flatten_nested_list``, every tensor
    element is flattened to 1-D, and the pieces are concatenated into a
    single tensor before hashing.

    CUDA tensors are hashed with ``gpu_tensor_hash``. Every other tensor is
    moved to host memory and hashed with ``data_hash`` over its raw bytes.
    """
    tensor = tensor_list
    if isinstance(tensor_list, list):
        tensor_list = flatten_nested_list(tensor_list)
        tensor_list = [
            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
        ]
        tensor = torch.concat(tensor_list)

    assert isinstance(tensor, torch.Tensor)
    if tensor.is_cuda:
        return gpu_tensor_hash(tensor)

    # `.cpu()` returns the tensor itself when it already lives on the CPU, so
    # this is free for host tensors while still supporting non-CUDA
    # accelerators, whose tensors cannot be exported with `.numpy()` directly.
    tensor = tensor.detach().cpu().contiguous()

    if tensor.dtype == torch.bfloat16:
        # bfloat16 tensors cannot be exported through numpy()/memoryview()
        tensor = tensor.float()

    mv = memoryview(tensor.numpy())
    return data_hash(mv.tobytes())
|
||||
|
||||
|
||||
def hash_feature(f):
    """Compute an integer hash for a feature object.

    Dispatch by type:
      * ``torch.Tensor``        -> ``tensor_hash`` on a one-element list
      * ``np.ndarray``          -> ``data_hash`` of the C-contiguous raw bytes
      * ``list`` of tensors     -> ``tensor_hash`` on the list
      * any other ``list``      -> ``data_hash`` of the flattened tuple
      * anything else           -> ``data_hash`` on the object itself
    """
    if isinstance(f, torch.Tensor):
        return tensor_hash([f])

    if isinstance(f, np.ndarray):
        return data_hash(np.ascontiguousarray(f).tobytes())

    if isinstance(f, list):
        # NOTE(review): an empty list raises IndexError on `f[0]` — confirm
        # callers never pass one.
        if isinstance(f[0], torch.Tensor):
            return tensor_hash(f)
        # NOTE(review): hashlib.sha256 expects a bytes-like object, so a
        # tuple here looks like it would raise TypeError — confirm whether
        # this path is reachable and serialize explicitly if so.
        flattened = tuple(flatten_nested_list(f))
        return data_hash(flattened)

    return data_hash(f)
|
||||
|
||||
@@ -3,7 +3,6 @@ import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
from enum import Enum, auto
|
||||
@@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||
ScheduleBatchDisaggregationDecodeMixin,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
@@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"max_micro_batch_size",
|
||||
"disable_shared_experts_fusion",
|
||||
"sampling_backend",
|
||||
"speculative_accept_threshold_acc",
|
||||
"speculative_accept_threshold_single",
|
||||
"speculative_accept_threshold_acc",
|
||||
"torchao_config",
|
||||
"triton_attention_reduce_in_fp32",
|
||||
"num_reserved_decode_tokens",
|
||||
@@ -180,7 +178,9 @@ class Modality(Enum):
|
||||
@dataclasses.dataclass
|
||||
class MultimodalDataItem:
|
||||
"""
|
||||
A single multimodal data, from a single image/video/audio or others.
|
||||
One MultimodalDataItem contains all inputs for one modality.
|
||||
For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
|
||||
One for images and one for audio.
|
||||
|
||||
We put the common fields first and the model-specific fields last.
|
||||
"""
|
||||
@@ -232,53 +232,7 @@ class MultimodalDataItem:
|
||||
"""
|
||||
Set the pad value after first hashing the data
|
||||
"""
|
||||
|
||||
def data_hash(data) -> int:
|
||||
hash_bytes = hashlib.sha256(data).digest()[:8]
|
||||
return int.from_bytes(hash_bytes, byteorder="big", signed=False)
|
||||
|
||||
def tensor_hash(tensor_list) -> int:
|
||||
"""
|
||||
hash a tensor or a tensor list
|
||||
"""
|
||||
tensor = tensor_list
|
||||
if isinstance(tensor_list, list):
|
||||
tensor_list = flatten_nested_list(tensor_list)
|
||||
tensor_list = [
|
||||
x.flatten() if isinstance(x, torch.Tensor) else x
|
||||
for x in tensor_list
|
||||
]
|
||||
tensor = torch.concat(tensor_list)
|
||||
if tensor.is_cuda:
|
||||
return gpu_tensor_hash(tensor)
|
||||
tensor = tensor.detach().contiguous()
|
||||
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
# memoryview() doesn't support PyTorch's BFloat16 dtype
|
||||
tensor = tensor.float()
|
||||
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
if tensor.is_cuda:
|
||||
# TODO: improve this
|
||||
tensor_cpu = tensor.cpu()
|
||||
else:
|
||||
tensor_cpu = tensor
|
||||
|
||||
mv = memoryview(tensor_cpu.numpy())
|
||||
return data_hash(mv.tobytes())
|
||||
|
||||
def hash_feature(f):
|
||||
if isinstance(f, list):
|
||||
if isinstance(f[0], torch.Tensor):
|
||||
return tensor_hash(f)
|
||||
return data_hash(tuple(flatten_nested_list(f)))
|
||||
elif isinstance(f, np.ndarray):
|
||||
arr = np.ascontiguousarray(f)
|
||||
arr_bytes = arr.tobytes()
|
||||
return data_hash(arr_bytes)
|
||||
elif isinstance(f, torch.Tensor):
|
||||
return tensor_hash([f])
|
||||
return data_hash(f)
|
||||
from sglang.srt.managers.mm_utils import hash_feature
|
||||
|
||||
if self.precomputed_features is not None:
|
||||
self.hash = hash_feature(self.precomputed_features)
|
||||
|
||||
@@ -418,14 +418,16 @@ class Scheduler(
|
||||
self.last_decode_stats_tic = time.perf_counter()
|
||||
self.last_prefill_stats_tic = time.perf_counter()
|
||||
self.return_health_check_ct = 0
|
||||
self.num_retracted_reqs: int = 0
|
||||
self.num_paused_reqs: int = 0
|
||||
self.kv_transfer_speed_gb_s: float = 0.0
|
||||
self.kv_transfer_latency_ms: float = 0.0
|
||||
self.sessions: Dict[str, Session] = {}
|
||||
self.current_stream = torch.get_device_module(self.device).current_stream()
|
||||
if self.device == "cpu":
|
||||
self.current_stream.synchronize = lambda: None # No-op for CPU
|
||||
self.forward_sleep_time = None
|
||||
|
||||
# Init session info
|
||||
self.sessions: Dict[str, Session] = {}
|
||||
|
||||
# Init chunked prefill
|
||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||
if self.chunked_prefill_size <= 0: # -1 means disable
|
||||
@@ -473,26 +475,12 @@ class Scheduler(
|
||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||
t.start()
|
||||
self.parent_process = psutil.Process().parent()
|
||||
|
||||
# Init memory saver, profiler and metric stats
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
|
||||
# Init profiler
|
||||
self.torch_profiler = None
|
||||
self.torch_profiler_output_dir: Optional[str] = None
|
||||
self.profiler_activities: Optional[List[str]] = None
|
||||
self.profile_id: Optional[str] = None
|
||||
self.profiler_target_forward_ct: Optional[int] = None
|
||||
self.profiler_target_prefill_ct: Optional[int] = None
|
||||
self.profiler_target_decode_ct: Optional[int] = None
|
||||
self.profiler_prefill_ct: Optional[int] = None
|
||||
self.profiler_decode_ct: Optional[int] = None
|
||||
self.profile_by_stage: bool = False
|
||||
self.profile_steps: Optional[int] = None
|
||||
self.profile_in_progress: bool = False
|
||||
self.rpd_profiler = None
|
||||
|
||||
# Init metrics stats
|
||||
self.init_profier()
|
||||
self.init_metrics()
|
||||
self.init_kv_events(server_args.kv_events_config)
|
||||
|
||||
@@ -526,6 +514,7 @@ class Scheduler(
|
||||
]
|
||||
)
|
||||
|
||||
# Init disaggregation
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
@@ -624,6 +613,21 @@ class Scheduler(
|
||||
)
|
||||
)
|
||||
|
||||
def init_profier(self):
    """Reset all profiler-related attributes on the scheduler to their defaults."""
    # Profiler objects start unset.
    self.torch_profiler = None
    self.rpd_profiler = None

    # Session description: identifier, output location and activity list.
    self.profile_id: Optional[str] = None
    self.torch_profiler_output_dir: Optional[str] = None
    self.profiler_activities: Optional[List[str]] = None

    # Target counts for forward / prefill / decode.
    self.profiler_target_forward_ct: Optional[int] = None
    self.profiler_target_prefill_ct: Optional[int] = None
    self.profiler_target_decode_ct: Optional[int] = None

    # Running prefill / decode counters.
    self.profiler_prefill_ct: Optional[int] = None
    self.profiler_decode_ct: Optional[int] = None

    # Mode flags and step budget.
    self.profile_steps: Optional[int] = None
    self.profile_by_stage: bool = False
    self.profile_in_progress: bool = False
|
||||
|
||||
def init_metrics(self):
|
||||
self.last_gen_throughput: float = 0.0
|
||||
self.last_input_throughput: float = 0.0
|
||||
@@ -2107,6 +2111,18 @@ class Scheduler(
|
||||
def get_internal_state(self, recv_req: GetInternalStateReq):
|
||||
ret = dict(global_server_args_dict)
|
||||
ret["last_gen_throughput"] = self.last_gen_throughput
|
||||
ret["memory_usage"] = {
|
||||
"weight": round(
|
||||
self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
|
||||
),
|
||||
"kvcache": round(
|
||||
self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
|
||||
),
|
||||
"cuda_graph": round(
|
||||
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
|
||||
),
|
||||
"token_capacity": int(self.max_total_num_tokens),
|
||||
}
|
||||
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
|
||||
ret["avg_spec_accept_length"] = (
|
||||
self.cum_spec_accept_length / self.cum_spec_accept_count
|
||||
|
||||
@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
|
||||
stream_interval = (
|
||||
req.sampling_params.stream_interval or self.stream_interval
|
||||
)
|
||||
should_output = len(req.output_ids) % stream_interval == 0
|
||||
should_output = (
|
||||
len(req.output_ids) % stream_interval == 1
|
||||
if not self.model_config.is_multimodal_gen
|
||||
and stream_interval > 1
|
||||
else len(req.output_ids) % stream_interval == 0
|
||||
)
|
||||
else:
|
||||
should_output = (
|
||||
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
|
||||
and not self.model_config.is_multimodal_gen
|
||||
if not self.model_config.is_multimodal_gen
|
||||
else False
|
||||
)
|
||||
|
||||
if should_output:
|
||||
|
||||
@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
get_dummy_processor,
|
||||
get_mm_processor,
|
||||
import_processors,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -187,6 +183,8 @@ class TokenizerManager:
|
||||
if server_args.preferred_sampling_params
|
||||
else None
|
||||
)
|
||||
self.crash_dump_folder = server_args.crash_dump_folder
|
||||
self.crash_dump_performed = False # Flag to ensure dump is only called once
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.asyncio.Context(2)
|
||||
@@ -251,10 +249,11 @@ class TokenizerManager:
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
self.dump_requests_threshold = 1000
|
||||
self.dump_request_list: List[Tuple] = []
|
||||
self.crash_dump_request_list: deque[Tuple] = deque()
|
||||
self.log_request_metadata = self.get_log_request_metadata()
|
||||
self.asyncio_tasks = set()
|
||||
self.session_futures = {} # session_id -> asyncio event
|
||||
self.max_req_input_len = None
|
||||
self.asyncio_tasks = set()
|
||||
|
||||
# The event to notify the weight sync is finished.
|
||||
self.model_update_lock = RWLock()
|
||||
@@ -266,14 +265,14 @@ class TokenizerManager:
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.transfer_backend = TransferBackend(
|
||||
self.disaggregation_transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
# Start kv bootstrap server on prefill
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# only start bootstrap server on prefill tm
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
@@ -324,7 +323,6 @@ class TokenizerManager:
|
||||
self.profile_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
||||
self.get_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -484,7 +482,7 @@ class TokenizerManager:
|
||||
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
||||
|
||||
if self.mm_processor and obj.contains_mm_input():
|
||||
image_inputs = await self.mm_processor.process_mm_data_async(
|
||||
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
image_data=obj.image_data,
|
||||
input_text=input_text or input_ids,
|
||||
request_obj=obj,
|
||||
@@ -547,6 +545,14 @@ class TokenizerManager:
|
||||
"Please set `--enable-custom-logits-processor` to enable this feature."
|
||||
)
|
||||
|
||||
def _validate_input_ids_in_vocab(
|
||||
self, input_ids: List[int], vocab_size: int
|
||||
) -> None:
|
||||
if any(id >= vocab_size for id in input_ids):
|
||||
raise ValueError(
|
||||
f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
|
||||
)
|
||||
|
||||
def _create_tokenized_object(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -1096,12 +1102,36 @@ class TokenizerManager:
|
||||
"image_data",
|
||||
"audio_data",
|
||||
"lora_path",
|
||||
"sampling_params",
|
||||
]
|
||||
)
|
||||
out_skip_names = set(
|
||||
[
|
||||
"text",
|
||||
"output_ids",
|
||||
]
|
||||
)
|
||||
out_skip_names = set(["text", "output_ids", "embedding"])
|
||||
elif self.log_requests_level == 1:
|
||||
max_length = 2048
|
||||
max_length = 1 << 30
|
||||
skip_names = set(
|
||||
[
|
||||
"text",
|
||||
"input_ids",
|
||||
"input_embeds",
|
||||
"image_data",
|
||||
"audio_data",
|
||||
"lora_path",
|
||||
]
|
||||
)
|
||||
out_skip_names = set(
|
||||
[
|
||||
"text",
|
||||
"output_ids",
|
||||
]
|
||||
)
|
||||
elif self.log_requests_level == 2:
|
||||
max_length = 2048
|
||||
elif self.log_requests_level == 3:
|
||||
max_length = 1 << 30
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -1118,6 +1148,8 @@ class TokenizerManager:
|
||||
self.dump_requests_folder = obj.dump_requests_folder
|
||||
if obj.dump_requests_threshold is not None:
|
||||
self.dump_requests_threshold = obj.dump_requests_threshold
|
||||
if obj.crash_dump_folder is not None:
|
||||
self.crash_dump_folder = obj.crash_dump_folder
|
||||
logging.info(f"Config logging: {obj=}")
|
||||
self.log_request_metadata = self.get_log_request_metadata()
|
||||
|
||||
@@ -1166,6 +1198,52 @@ class TokenizerManager:
|
||||
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
|
||||
)
|
||||
|
||||
def dump_requests_before_crash(self):
    """Pickle recently finished and still-unfinished requests before exiting.

    The dump runs at most once per process (guarded by
    ``crash_dump_performed``) and is skipped when ``crash_dump_folder`` is
    not set or there is nothing to write. The output file is
    ``<crash_dump_folder>/<hostname>/crash_dump_<timestamp>.pkl`` and holds
    a dict with the server args and the list of request tuples
    ``(obj, out_dict, created_time, finish_or_dump_time)``.
    """
    if self.crash_dump_performed:
        logger.info(
            "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
        )
        return
    logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
    self.crash_dump_performed = True
    if not self.crash_dump_folder:
        return

    data_to_dump = []
    if self.crash_dump_request_list:
        data_to_dump.extend(self.crash_dump_request_list)

    # Add unfinished requests from rid_to_state
    unfinished_requests = [
        (state.obj, {}, state.created_time, time.time())
        for state in self.rid_to_state.values()
        if not state.finished
    ]
    data_to_dump.extend(unfinished_requests)

    if not data_to_dump:
        return

    # HOSTNAME is not guaranteed to be exported; passing None to
    # os.path.join would raise TypeError and lose the dump, so fall back to
    # the hostname reported by the OS.
    import socket

    hostname = os.getenv("HOSTNAME") or socket.gethostname()
    filename = os.path.join(
        self.crash_dump_folder,
        hostname,
        f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
    )

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Include server_args in the dump
    data_to_dump_with_server_args = {
        "server_args": self.server_args,
        "requests": data_to_dump,
    }
    with open(filename, "wb") as f:
        pickle.dump(data_to_dump_with_server_args, f)
    logger.error(
        f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
    )
|
||||
|
||||
async def sigterm_watchdog(self):
|
||||
while not self.gracefully_exit:
|
||||
await asyncio.sleep(5)
|
||||
@@ -1175,11 +1253,12 @@ class TokenizerManager:
|
||||
remain_num_req = len(self.rid_to_state)
|
||||
|
||||
if self.health_check_failed:
|
||||
# if health check failed, exit immediately
|
||||
# if health check failed, we should exit immediately
|
||||
logger.error(
|
||||
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
||||
remain_num_req,
|
||||
)
|
||||
self.dump_requests_before_crash()
|
||||
break
|
||||
|
||||
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
|
||||
@@ -1196,6 +1275,7 @@ class TokenizerManager:
|
||||
if remain_num_req > 0:
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
self.dump_requests_before_crash()
|
||||
break
|
||||
|
||||
kill_process_tree(os.getpid(), include_parent=True)
|
||||
@@ -1273,16 +1353,7 @@ class TokenizerManager:
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
||||
if isinstance(recv_obj.outputs[i], str):
|
||||
out_dict = {
|
||||
"text": recv_obj.outputs[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
else:
|
||||
out_dict = {
|
||||
"outputs": json.dumps(recv_obj.outputs[i]),
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
raise NotImplementedError("BatchMultimodalOut not implemented")
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
out_dict = {
|
||||
@@ -1306,6 +1377,8 @@ class TokenizerManager:
|
||||
self.collect_metrics(state, recv_obj, i)
|
||||
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
|
||||
self.dump_requests(state, out_dict)
|
||||
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
|
||||
self.record_request_for_crash_dump(state, out_dict)
|
||||
|
||||
def convert_logprob_style(
|
||||
self,
|
||||
@@ -1317,6 +1390,9 @@ class TokenizerManager:
|
||||
recv_obj: BatchStrOut,
|
||||
recv_obj_index: int,
|
||||
):
|
||||
if recv_obj.input_token_logprobs_val is None:
|
||||
return
|
||||
|
||||
if len(recv_obj.input_token_logprobs_val) > 0:
|
||||
state.input_token_logprobs_val.extend(
|
||||
recv_obj.input_token_logprobs_val[recv_obj_index]
|
||||
@@ -1436,7 +1512,10 @@ class TokenizerManager:
|
||||
else 0
|
||||
)
|
||||
|
||||
if state.first_token_time == 0.0:
|
||||
if (
|
||||
state.first_token_time == 0.0
|
||||
and self.disaggregation_mode != DisaggregationMode.PREFILL
|
||||
):
|
||||
state.first_token_time = state.last_time = time.time()
|
||||
state.last_completion_tokens = completion_tokens
|
||||
self.metrics_collector.observe_time_to_first_token(
|
||||
@@ -1484,14 +1563,31 @@ class TokenizerManager:
|
||||
to_dump = self.dump_request_list
|
||||
self.dump_request_list = []
|
||||
|
||||
to_dump_with_server_args = {
|
||||
"server_args": self.server_args,
|
||||
"requests": to_dump,
|
||||
}
|
||||
|
||||
def background_task():
|
||||
os.makedirs(self.dump_requests_folder, exist_ok=True)
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(to_dump, f)
|
||||
pickle.dump(to_dump_with_server_args, f)
|
||||
|
||||
# Schedule the task to run in the background without awaiting it
|
||||
asyncio.create_task(asyncio.to_thread(background_task))
|
||||
|
||||
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
    """Remember a finished request so it can be included in a crash dump.

    Appends ``(state.obj, out_dict, state.created_time, now)`` to
    ``crash_dump_request_list`` and then drops entries from the front of
    the deque whose recorded finish time is 300 seconds or more in the past.
    """
    now = time.time()
    history = self.crash_dump_request_list
    history.append((state.obj, out_dict, state.created_time, now))

    # Entries are appended in finish-time order, so expired ones sit at the left.
    while history and now - history[0][3] >= 300:
        history.popleft()
|
||||
|
||||
def _handle_abort_req(self, recv_obj):
|
||||
self.rid_to_state.pop(recv_obj.rid, None)
|
||||
|
||||
@@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func):
|
||||
except Exception:
|
||||
traceback = get_exception_traceback()
|
||||
logger.error(f"TokenizerManager hit an exception: {traceback}")
|
||||
if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
|
||||
func.__self__.dump_requests_before_crash()
|
||||
kill_process_tree(os.getpid(), include_parent=True)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1632,6 +1730,7 @@ class SignalHandler:
|
||||
logger.error(
|
||||
"Received sigquit from a child process. It usually means the child failed."
|
||||
)
|
||||
self.tokenizer_manager.dump_requests_before_crash()
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user