diff --git a/docs/references/production_request_trace.md b/docs/references/production_request_trace.md new file mode 100644 index 000000000..928e5fd3f --- /dev/null +++ b/docs/references/production_request_trace.md @@ -0,0 +1,118 @@ +SGLang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` flag and configuring the OpenTelemetry Collector endpoint with `--oltp-traces-endpoint` when launching the server. + +You can find example screenshots of the visualization at https://github.com/sgl-project/sglang/issues/8965. + +## Setup Guide +This section explains how to configure request tracing and export the trace data. +1. Install the required packages and tools + * Install Docker and Docker Compose + * Install the dependencies + ```bash + # enter the SGLang root directory + pip install -e "python[tracing]" + + # or manually install the dependencies using pip + pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc + ``` + +2. Launch the OpenTelemetry Collector and Jaeger + ```bash + docker compose -f examples/monitoring/tracing_compose.yaml up -d + ``` + +3. Start your SGLang server with tracing enabled + ```bash + python -m sglang.launch_server --enable-trace --oltp-traces-endpoint 0.0.0.0:4317 + ``` + + Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the collector with `tracing_compose.yaml`, the default receiving port is 4317. + +4. Send some requests +5. Observe whether trace data is being exported + * Access port 16686 of Jaeger in a web browser to visualize the request traces. + * The OpenTelemetry Collector also exports trace data in JSON format to `/tmp/otel_trace.json`. In a follow-up patch, we will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI. + +## How to Add Tracing for Slices You're Interested In 
+We have already inserted instrumentation points in the tokenizer and scheduler main threads. If you wish to trace additional request execution segments or perform finer-grained tracing, use the APIs from the tracing package as described below. + +1. Initialization + + Every process involved in tracing should execute the following during initialization: + ```python + process_tracing_init(oltp_traces_endpoint, server_name) + ``` + The `oltp_traces_endpoint` is obtained from the server arguments; `server_name` can be set freely but must remain consistent across all processes. + + Every thread involved in tracing should execute the following during initialization: + ```python + trace_set_thread_info("thread label", tp_rank, dp_rank) + ``` + The "thread label" serves as the thread's name and is used to distinguish threads in the visualization view. + +2. Mark the beginning and end of a request + ```python + trace_req_start(rid, bootstrap_room) + trace_req_finish(rid) + ``` + These two APIs must be called within the same process, for example, in the tokenizer. + +3. Add tracing for slices + + - Add slice tracing normally: + ```python + trace_slice_start("slice A", rid) + trace_slice_end("slice A", rid) + ``` + + - Use the `anonymous` flag to leave the slice name unspecified at the start of a slice, deferring it to `trace_slice_end`. +
Note: Anonymous slices must not be nested. + ```python + trace_slice_start("", rid, anonymous = True) + trace_slice_end("slice A", rid) + ``` + + - In `trace_slice_end`, use `auto_next_anon` to automatically create the next anonymous slice, which reduces the number of instrumentation points needed. + ```python + trace_slice_start("", rid, anonymous = True) + trace_slice_end("slice A", rid, auto_next_anon = True) + trace_slice_end("slice B", rid, auto_next_anon = True) + trace_slice_end("slice C", rid, auto_next_anon = True) + trace_slice_end("slice D", rid) + ``` + - The end of the last slice in a thread must be marked with `thread_finish_flag=True`; otherwise, the thread's span will not be properly generated. + ```python + trace_slice_end("slice D", rid, thread_finish_flag = True) + ``` + +4. When the request execution flow transfers to another thread, the trace context must be explicitly propagated. + - Sender: execute the following code before sending the request to another thread via ZMQ: + ```python + trace_context = trace_get_proc_propagate_context(rid) + req.trace_context = trace_context + ``` + - Receiver: execute the following code after receiving the request via ZMQ: + ```python + trace_set_proc_propagate_context(rid, req.trace_context) + ``` + +## How to Extend the Tracing Framework to Support Complex Tracing Scenarios + +The tracing package still has room for further development. If you wish to build more advanced features upon it, you should first understand its existing design. + +The core of the tracing framework's implementation lies in the design of the trace context. 
To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context structure: `SglangTraceReqContext`, `SglangTraceThreadContext`, and `SglangTraceSliceContext`. Their relationship is as follows: +``` +SglangTraceReqContext (req_id="req-123") +├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0) +│ └── SglangTraceSliceContext (name="prefill") # cur slice +│ +└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1) + └── SglangTraceSliceContext (name="prefill") # cur slice +``` + +Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is represented by a `SglangTraceSliceContext` stored in the `SglangTraceThreadContext`. When slice, thread, or request tracing ends, a span is generated and the corresponding context is released. + +In addition to the above hierarchy, each slice also records its previous slice via `Span.add_link()`, which can be used to trace the execution flow. + +When the request execution flow transfers to a new thread, the trace context must be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span. 
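The propagation handshake can be illustrated with a small standalone sketch. This is a toy stand-in, not the real SGLang implementation (the real one serializes OpenTelemetry span contexts with `propagate.inject`/`extract`); the class and field names merely mirror the real ones. The sender packs the request's root-span reference and the previous slice's span id into a plain dict carrier, and the receiver rebuilds a replica request context (`is_copy=True`) from it:

```python
# Toy sketch of the trace-context propagation pattern described above.
# Illustrative only -- names mirror the real classes but this is not the
# actual SGLang tracing implementation.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class SliceCtx:          # stands in for SglangTraceSliceContext
    name: str
    span_id: int


@dataclass
class ThreadCtx:         # stands in for SglangTraceThreadContext
    thread_label: str
    cur_slice_stack: List[SliceCtx] = field(default_factory=list)
    last_span_id: Optional[int] = None


@dataclass
class ReqCtx:            # stands in for SglangTraceReqContext
    rid: str
    threads: Dict[str, ThreadCtx] = field(default_factory=dict)
    is_copy: bool = False


def get_propagate_context(req: ReqCtx, thread: str) -> dict:
    """Sender side: pack the root-span reference and the previous span id."""
    tctx = req.threads[thread]
    if tctx.cur_slice_stack:            # an open slice is the previous span
        prev = tctx.cur_slice_stack[-1].span_id
    else:                               # otherwise, the last finished slice
        prev = tctx.last_span_id
    return {"root_span": req.rid, "prev_span": prev}


def set_propagate_context(store: Dict[str, ReqCtx], carrier: dict, thread: str) -> ReqCtx:
    """Receiver side: rebuild a replica request context from the carrier."""
    rid = carrier["root_span"]
    req = store.setdefault(rid, ReqCtx(rid=rid, is_copy=True))
    tctx = req.threads.setdefault(thread, ThreadCtx(thread_label=thread))
    tctx.last_span_id = carrier["prev_span"]  # link target for the next slice
    return req
```

For example, a tokenizer thread whose last completed slice had span id 42 would hand `{"root_span": "req-123", "prev_span": 42}` to the scheduler, whose first slice then links back to span 42.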
diff --git a/examples/monitoring/opentelemetry.yaml b/examples/monitoring/opentelemetry.yaml new file mode 100644 index 000000000..8593d9182 --- /dev/null +++ b/examples/monitoring/opentelemetry.yaml @@ -0,0 +1,38 @@ +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + http: + endpoint: 0.0.0.0:4318 +processors: + batch: + +exporters: + otlp: + endpoint: jaeger:4317 + tls: + insecure: true + file: + path: /tmp/otel_trace.json + +extensions: + health_check: + pprof: + zpages: + +service: + extensions: [health_check, pprof, zpages] + pipelines: + traces: + receivers: [otlp] + processors: [batch] + exporters: [otlp, file] + metrics: + receivers: [otlp] + processors: [batch] + exporters: [otlp] + logs: + receivers: [otlp] + processors: [batch] + exporters: [otlp] diff --git a/examples/monitoring/tracing_compose.yaml b/examples/monitoring/tracing_compose.yaml new file mode 100644 index 000000000..7ed1ecdda --- /dev/null +++ b/examples/monitoring/tracing_compose.yaml @@ -0,0 +1,21 @@ +services: + otel-collector: + image: docker.io/otel/opentelemetry-collector + volumes: + - ./opentelemetry.yaml:/etc/otelcol/config.yaml + - /tmp:/tmp + ports: + - "4317:4317" # OTLP gRPC + - "4318:4318" # OTLP HTTP + depends_on: + - jaeger + restart: unless-stopped + + jaeger: + image: jaegertracing/all-in-one + container_name: jaeger + ports: + - "16686:16686" + environment: + - COLLECTOR_OTLP_ENABLED=true + restart: unless-stopped diff --git a/python/pyproject.toml b/python/pyproject.toml index 060e918b4..403f68143 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -56,6 +56,13 @@ runtime_common = [ "xgrammar==0.1.24", ] +tracing = [ + "opentelemetry-sdk", + "opentelemetry-api", + "opentelemetry-exporter-otlp", + "opentelemetry-exporter-otlp-proto-grpc", +] + srt = [ "sglang[runtime_common]", "sgl-kernel==0.3.9.post2", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 91e0d695c..390e0985a 100644 --- 
a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -33,6 +33,8 @@ import zmq import zmq.asyncio from PIL.Image import Image +from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -138,6 +140,12 @@ class Engine(EngineBase): context, zmq.DEALER, self.port_args.rpc_ipc_name, True ) + if server_args.enable_trace: + process_tracing_init(server_args.oltp_traces_endpoint, "sglang") + if server_args.disaggregation_mode == "null": + thread_label = "Tokenizer" + trace_set_thread_info(thread_label) + def generate( self, # The input prompt. It can be a single prompt or a batch of prompts. diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 524db4693..28aa897f7 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -31,6 +31,8 @@ from typing import Any, AsyncIterator, Callable, Dict, List, Optional import setproctitle +from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -179,6 +181,13 @@ async def init_multi_tokenizer() -> ServerArgs: scheduler_info=scheduler_info, ) ) + + if server_args.enable_trace: + process_tracing_init(server_args.oltp_traces_endpoint, "sglang") + if server_args.disaggregation_mode == "null": + thread_label = f"MultiTokenizer-{tokenizer_manager.worker_id}" + trace_set_thread_info(thread_label) + return server_args @@ -1203,6 +1212,12 @@ def launch_server( server_args=server_args, ) + if server_args.enable_trace: + process_tracing_init(server_args.oltp_traces_endpoint, "sglang") + if server_args.disaggregation_mode == "null": + thread_label = "Tokenizer" + trace_set_thread_info(thread_label) + set_global_state( _GlobalState( 
tokenizer_manager=tokenizer_manager, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 093060174..152c5a915 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -605,6 +605,9 @@ class TokenizedGenerateReqInput: # Image gen grpc migration return_bytes: bool = False + # tracing context + trace_context: Optional[Dict] = None + @dataclass class BatchTokenizedGenerateReqInput: @@ -654,6 +657,9 @@ class EmbeddingReqInput: # For background responses (OpenAI responses API) background: bool = False + # tracing context + trace_context: Optional[Dict] = None + def normalize_batch_and_arguments(self): # at least one of text, input_ids, or image should be provided if self.text is None and self.input_ids is None and self.image_data is None: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c66354be0..be7bc6a4a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -149,6 +149,15 @@ from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.tracing.trace import ( + process_tracing_init, + trace_event, + trace_set_proc_propagate_context, + trace_set_thread_info, + trace_slice, + trace_slice_end, + trace_slice_start, +) from sglang.srt.two_batch_overlap import TboDPAttentionPreparer from sglang.srt.utils import ( DynamicGradMode, @@ -826,6 +835,10 @@ class Scheduler( batch = self.get_next_batch_to_run() self.cur_batch = batch + if batch: + for req in batch.reqs: + trace_event("schedule", req.rid) + if batch: result = self.run_batch(batch) self.process_batch_result(batch, result) @@ -847,6 +860,10 @@ class Scheduler( batch = self.get_next_batch_to_run() self.cur_batch = batch 
+ if batch: + for req in batch.reqs: + trace_event("schedule", req.rid) + if batch: batch.launch_done = threading.Event() result = self.run_batch(batch) @@ -1110,6 +1127,12 @@ class Scheduler( self.tp_cpu_group, src=self.tp_group.ranks[0], ) + + for req in recv_reqs: + if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)): + trace_set_proc_propagate_context(req.rid, req.trace_context) + trace_slice_start("", req.rid, anonymous=True) + return recv_reqs def process_input_requests(self, recv_reqs: List): @@ -1347,6 +1370,7 @@ class Scheduler( else: self._prefetch_kvcache(req) self.waiting_queue.append(req) + trace_slice_end("process req", req.rid, auto_next_anon=True) def _prefetch_kvcache(self, req: Req): if self.enable_hicache_storage: @@ -1914,8 +1938,23 @@ class Scheduler( ): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result, launch_done) + for req in batch.reqs: + trace_slice( + "decode loop", + req.rid, + auto_next_anon=not req.finished(), + thread_finish_flag=req.finished(), + ) + elif batch.forward_mode.is_extend(): self.process_batch_result_prefill(batch, result, launch_done) + for req in batch.reqs: + trace_slice( + "prefill", + req.rid, + auto_next_anon=not req.finished(), + thread_finish_flag=req.finished(), + ) elif batch.forward_mode.is_idle(): if self.enable_overlap: self.tp_worker.resolve_last_batch_result(launch_done) @@ -2600,6 +2639,12 @@ def run_scheduler_process( pipe_writer, balance_meta: Optional[DPBalanceMeta] = None, ): + if server_args.enable_trace: + process_tracing_init(server_args.oltp_traces_endpoint, "sglang") + if server_args.disaggregation_mode == "null": + thread_label = "Scheduler" + trace_set_thread_info(thread_label, tp_rank, dp_rank) + if (numa_node := server_args.numa_node) is not None: numa_bind_to_node(numa_node[gpu_id]) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index be2e7c654..25b5fd87c 100644 --- 
a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -82,6 +82,13 @@ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicat from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.tracing.trace import ( + trace_get_proc_propagate_context, + trace_req_finish, + trace_req_start, + trace_slice_end, + trace_slice_start, +) from sglang.srt.utils import ( configure_gc_warning, dataclass_to_string_truncated, @@ -358,6 +365,24 @@ class TokenizerManager(TokenizerCommunicatorMixin): # If it's a single value, add worker_id prefix obj.rid = f"{self.worker_id}_{obj.rid}" + if obj.is_single: + bootstrap_room = ( + obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None + ) + trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9)) + trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True) + else: + for i in range(len(obj.rid)): + bootstrap_room = ( + obj.bootstrap_room[i] + if hasattr(obj, "bootstrap_room") and obj.bootstrap_room + else None + ) + trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9)) + trace_slice_start( + "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True + ) + if self.log_requests: max_length, skip_names, _ = self.log_request_metadata logger.info( @@ -574,6 +599,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): mm_inputs = None self._validate_one_request(obj, input_ids) + trace_slice_end("tokenize", obj.rid) return self._create_tokenized_object( obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids ) @@ -752,6 +778,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): req, req.text, input_ids_list[i], None, None, token_type_ids ) ) + trace_slice_end("tokenize", req.rid) logger.debug(f"Completed batch processing for {batch_size} requests") return 
tokenized_objs @@ -779,9 +806,12 @@ class TokenizerManager(TokenizerCommunicatorMixin): tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], created_time: Optional[float] = None, ): + trace_slice_start("dispatch", obj.rid) + tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid) self.send_to_scheduler.send_pyobj(tokenized_obj) state = ReqState([], False, asyncio.Event(), obj, created_time=created_time) self.rid_to_state[obj.rid] = state + trace_slice_end("dispatch", obj.rid, thread_finish_flag=True) return state def _send_batch_request( @@ -1429,6 +1459,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] state.finished_time = time.time() meta_info["e2e_latency"] = state.finished_time - state.created_time + + trace_req_finish(rid, ts=int(state.finished_time * 1e9)) + del self.rid_to_state[rid] # Mark ongoing LoRA request as finished. diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ce67d1f7b..925f51ea1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -215,6 +215,8 @@ class ServerArgs: enable_request_time_stats_logging: bool = False kv_events_config: Optional[str] = None gc_warning_threshold_secs: float = 0.0 + enable_trace: bool = False + oltp_traces_endpoint: str = "localhost:4317" # API related api_key: Optional[str] = None @@ -1390,6 +1392,17 @@ class ServerArgs: default=None, help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.", ) + parser.add_argument( + "--enable-trace", + action="store_true", + help="Enable opentelemetry trace", + ) + parser.add_argument( + "--oltp-traces-endpoint", + type=str, + default="localhost:4317", + help="Config opentelemetry collector endpoint if --enable-trace is set. 
format: :", + ) # API related parser.add_argument( diff --git a/python/sglang/srt/tracing/trace.py b/python/sglang/srt/tracing/trace.py new file mode 100644 index 000000000..b4ccbfa9f --- /dev/null +++ b/python/sglang/srt/tracing/trace.py @@ -0,0 +1,552 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""package for sglang requests tracing""" + +from __future__ import annotations + +import ctypes +import logging +import os +import random +import threading +import time +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) +opentelemetry_imported = False +tracing_enabled = False + +try: + from opentelemetry import context, propagate, trace + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.resources import SERVICE_NAME, Resource + from opentelemetry.sdk.trace import TracerProvider, id_generator + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + opentelemetry_imported = True +except ImportError: + + class id_generator: + class IdGenerator: + pass + + logger.info("opentelemetry package is not installed, tracing disabled") + + +@dataclass +class SglangTraceThreadInfo: + host_id: str + pid: int + thread_label: str + tp_rank: int + dp_rank: int + tracer: trace.Tracer + + +@dataclass +class 
SglangTraceSliceContext: + slice_name: str + span: Optional[trace.span.Span] = None + # When True, defers slice_name assignment until trace_slice_end() + anonymous: bool = False + + +@dataclass +class SglangTraceThreadContext: + thread_info: SglangTraceThreadInfo + cur_slice_stack: List[SglangTraceSliceContext] + thread_span: Optional[trace.span.Span] = None + # Record the most recently completed span as the previous span for the next span to be created. + last_span_context: Optional[trace.span.SpanContext] = None + + +@dataclass +class SglangTraceReqContext: + rid: str + start_time_ns: int + threads_context: Dict[int, SglangTraceThreadContext] + bootstrap_room: Optional[int] = None + + # Indicates whether this instance is a replica from the main process. + # When True, root_span is None and only root_span_context is preserved. + is_copy: bool = False + root_span: Optional[trace.span.Span] = None + root_span_context: Optional[context.Context] = None + + +@dataclass +class SglangTracePropagateContext: + root_span_context: context.Context + prev_span_context: Optional[trace.span.SpanContext] + + def to_dict(self): + carrier: dict[str, str] = {} + context.attach(self.root_span_context) + propagate.inject(carrier) + + if self.prev_span_context: + return { + "root_span": carrier, + "prev_span": { + "span_id": self.prev_span_context.span_id, + "trace_id": self.prev_span_context.trace_id, + }, + } + else: + return {"root_span": carrier, "prev_span": "None"} + + @classmethod + def instance_from_dict(cls, d): + if "root_span" not in d or "prev_span" not in d: + return None + + carrier = d["root_span"] + root_span_context = propagate.extract(carrier) + + if d["prev_span"] == "None": + prev_span_context = None + else: + prev_span_context = trace.span.SpanContext( + trace_id=d["prev_span"]["trace_id"], + span_id=d["prev_span"]["span_id"], + is_remote=True, + ) + + return cls(root_span_context, prev_span_context) + + +class 
SglangTraceCustomIdGenerator(id_generator.IdGenerator): + """ + The default IdGenerator may produce duplicate trace IDs across multiple TP scheduler processes, + hence a custom IdGenerator is implemented. + """ + + def __init__(self): + super().__init__() + self.local_random = random.Random() + self.local_random.seed(time.time()) + + def generate_trace_id(self) -> int: + return self.local_random.getrandbits(64) + + def generate_span_id(self) -> int: + return self.local_random.getrandbits(64) + + +# global variables +threads_info: Dict[int, SglangTraceThreadInfo] = {} +reqs_context: Dict[str, SglangTraceReqContext] = {} + +__get_cur_time_ns = lambda: int(time.time() * 1e9) + + +def __get_host_id() -> str: + """ + In distributed tracing systems, obtain a unique node identifier + and inject it into all subsequently generated spans + to prevent PID conflicts between threads on different nodes. + """ + if os.path.exists("/etc/machine-id"): + try: + with open("/etc/machine-id", "r") as f: + return f.read().strip() + except: + pass + + mac = uuid.getnode() + if mac != 0: + return uuid.UUID(int=mac).hex + + return "unknown" + + +# Should be called by each tracked process. 
+def process_tracing_init(otlp_endpoint, server_name): + global tracing_enabled + global __get_cur_time_ns + if not opentelemetry_imported: + tracing_enabled = False + return + + try: + resource = Resource.create( + attributes={ + SERVICE_NAME: server_name, + } + ) + tracer_provider = TracerProvider( + resource=resource, id_generator=SglangTraceCustomIdGenerator() + ) + + processor = BatchSpanProcessor( + OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) + ) + tracer_provider.add_span_processor(processor) + trace.set_tracer_provider(tracer_provider) + except Exception as e: + logger.error(f"initialize opentelemetry error: {e}") + logger.warning("please set a correct otlp endpoint") + tracing_enabled = False + return + + if hasattr(time, "time_ns"): + __get_cur_time_ns = lambda: int(time.time_ns()) + + tracing_enabled = True + + +# Should be called by each tracked thread. +def trace_set_thread_info( + thread_label: str, tp_rank: Optional[int] = None, dp_rank: Optional[int] = None +): + if not tracing_enabled: + return + + pid = threading.get_native_id() + if pid in threads_info: + return + + threads_info[pid] = SglangTraceThreadInfo( + host_id=__get_host_id(), + pid=pid, + thread_label=thread_label, + tp_rank=tp_rank, + dp_rank=dp_rank, + tracer=trace.get_tracer("sglang server"), + ) + + +def __create_thread_context(pid, req_span_context, ts: Optional[int] = None): + if pid not in threads_info: + trace_set_thread_info("unknown") + + thread_info = threads_info[pid] + thread_context = SglangTraceThreadContext( + thread_info=thread_info, + cur_slice_stack=[], + ) + + thread_name = f"{thread_info.thread_label}" + if thread_info.tp_rank is not None: + thread_name += f" [TP {thread_info.tp_rank}] " + thread_name += f"(host:{thread_info.host_id[:8]} | pid:{pid})" + ts = ts or __get_cur_time_ns() + thread_context.thread_span = thread_context.thread_info.tracer.start_span( + name=thread_name, + start_time=ts, + context=req_span_context, + ) + + if thread_info.tp_rank 
is not None: + thread_context.thread_span.set_attributes({"tp_rank": thread_info.tp_rank}) + + thread_context.thread_span.set_attributes( + { + "host_id": thread_info.host_id, + "pid": thread_info.pid, + "thread_label": thread_info.thread_label, + } + ) + + return thread_context + + +def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]: + if not tracing_enabled: + return None + + rid = str(rid) + if rid not in reqs_context or not reqs_context[rid].root_span_context: + return None + + pid = threading.get_native_id() + prev_span_context = None + thread_context = reqs_context[rid].threads_context[pid] + if thread_context.cur_slice_stack: + cur_slice_info = thread_context.cur_slice_stack[0] + prev_span_context = cur_slice_info.span.get_span_context() + elif thread_context.last_span_context: + prev_span_context = thread_context.last_span_context + + trace_context = SglangTracePropagateContext( + reqs_context[rid].root_span_context, prev_span_context + ) + return trace_context.to_dict() + + +def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]]): + if not tracing_enabled: + return + if not trace_context: + return + + trace_context = SglangTracePropagateContext.instance_from_dict(trace_context) + if not trace_context: + return + + rid = str(rid) + # Create a copy of the request context + if rid not in reqs_context: + reqs_context[rid] = SglangTraceReqContext( + rid=rid, + start_time_ns=__get_cur_time_ns(), + threads_context={}, + root_span_context=trace_context.root_span_context, + is_copy=True, + ) + + pid = threading.get_native_id() + + if pid in reqs_context[rid].threads_context: + return + + # Create new thread context. 
+ reqs_context[rid].threads_context[pid] = __create_thread_context( + pid, + trace_context.root_span_context, + reqs_context[rid].start_time_ns, + ) + + reqs_context[rid].threads_context[ + pid + ].last_span_context = trace_context.prev_span_context + + +def trace_req_start( + rid: str, + bootstrap_room: Optional[int] = None, + ts: Optional[int] = None, +): + if not tracing_enabled: + return + + rid = str(rid) + + ts = ts or __get_cur_time_ns() + + pid = threading.get_native_id() + if pid not in threads_info: + return + + # create req context and root span + reqs_context[rid] = SglangTraceReqContext( + rid=rid, + start_time_ns=ts, + threads_context={}, + bootstrap_room=bootstrap_room, + is_copy=False, + ) + + # Drop the worker_id added by MultiTokenizer + orig_rid = rid.split("_")[-1] + tracer = threads_info[pid].tracer + root_span = tracer.start_span( + name=f"Req {orig_rid[:8]}", + start_time=ts, + ) + + root_span.set_attributes( + { + "rid": rid, + "bootstrap_room": bootstrap_room if bootstrap_room else "None", + } + ) + + reqs_context[rid].root_span = root_span + reqs_context[rid].root_span_context = trace.set_span_in_context(root_span) + + # create thread context and thread span + reqs_context[rid].threads_context[pid] = __create_thread_context( + pid, + reqs_context[rid].root_span_context, + ts, + ) + + +def trace_req_finish( + rid: str, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None +): + if not tracing_enabled: + return + + rid = str(rid) + if rid not in reqs_context: + return + + req_context = reqs_context[rid] + ts = ts or __get_cur_time_ns() + + # End all unclosed thread spans. 
+ for thread_context in req_context.threads_context.values(): + thread_context.thread_span.end(end_time=ts) + + if attrs: + req_context.root_span.set_attributes(attrs) + + req_context.root_span.end(end_time=ts) + + del reqs_context[rid] + + +def trace_slice_start( + name: str, + rid: str, + ts: Optional[int] = None, + anonymous: bool = False, +): + + rid = str(rid) + if not tracing_enabled or rid not in reqs_context: + return + + pid = threading.get_native_id() + if pid not in reqs_context[rid].threads_context: + return + + thread_context = reqs_context[rid].threads_context[pid] + + ts = ts or __get_cur_time_ns() + + slice_info = SglangTraceSliceContext( + slice_name=name, + anonymous=anonymous, + ) + + # find prev slice + prev_span_context = None + if not thread_context.cur_slice_stack: + if thread_context.last_span_context: + prev_span_context = thread_context.last_span_context + + parent_span = thread_context.thread_span + if thread_context.cur_slice_stack: + parent_span = thread_context.cur_slice_stack[-1].span + + parent_span_context = trace.set_span_in_context(parent_span) + span = thread_context.thread_info.tracer.start_span( + name=slice_info.slice_name, + start_time=ts, + context=parent_span_context, + ) + + if prev_span_context: + span.add_link(prev_span_context) + + slice_info.span = span + + thread_context.cur_slice_stack.append(slice_info) + + +def trace_slice_end( + name: str, + rid: str, + ts: Optional[int] = None, + attrs: Optional[Dict[str, Any]] = None, + auto_next_anon: bool = False, + thread_finish_flag: bool = False, +): + rid = str(rid) + if not tracing_enabled or rid not in reqs_context: + return + + pid = threading.get_native_id() + if pid not in reqs_context[rid].threads_context: + return + + thread_context = reqs_context[rid].threads_context[pid] + + if not thread_context.cur_slice_stack: + logger.warning(f"No matching trace_slice_start for slice '{name}'.") + return + + ts = ts or __get_cur_time_ns() + slice_info = 
thread_context.cur_slice_stack[-1] + span = slice_info.span + + if slice_info.anonymous: + span.update_name(name) + else: + span = slice_info.span + if slice_info.slice_name != name: + span.set_status(trace.Status(trace.StatusCode.ERROR)) + logger.warning(f"Slice name mismatch: {name} != {slice_info.slice_name}") + + if attrs: + span.set_attributes(attrs) + + span.end(end_time=ts) + + thread_context.cur_slice_stack.pop() + if len(thread_context.cur_slice_stack) == 0: + thread_context.last_span_context = span.get_span_context() + + # If this is the last slice in the thread, + # release the thread context and check whether to release the request context. + if thread_finish_flag: + thread_context.thread_span.end(end_time=ts) + del reqs_context[rid].threads_context[pid] + if reqs_context[rid].is_copy and not reqs_context[rid].threads_context: + del reqs_context[rid] + return + + if auto_next_anon: + trace_slice_start("", rid, ts, True) + + +# alias +trace_slice = trace_slice_end + + +# Add event to the current slice on the same thread with the same rid. +def trace_event(name: str, rid: str, ts: Optional[int] = None): + if not tracing_enabled or rid not in reqs_context: + return + + rid = str(rid) + pid = threading.get_native_id() + if pid not in reqs_context[rid].threads_context: + return + + thread_context = reqs_context[rid].threads_context[pid] + + if not thread_context.cur_slice_stack: + logger.warning(f"No slice is currently being traced.") + return + + ts = ts or __get_cur_time_ns() + + slice_info = thread_context.cur_slice_stack[-1] + slice_info.span.add_event(name=name, timestamp=ts) + + +# Add attrs to the current slice on the same thread with the same rid. 
+def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
+    rid = str(rid)
+    if not tracing_enabled or rid not in reqs_context:
+        return
+
+    pid = threading.get_native_id()
+    if pid not in reqs_context[rid].threads_context:
+        return
+
+    thread_context = reqs_context[rid].threads_context[pid]
+
+    if not thread_context.cur_slice_stack:
+        logger.warning("No slice is currently being traced.")
+        return
+
+    slice_info = thread_context.cur_slice_stack[-1]
+    slice_info.span.set_attributes(attrs)
diff --git a/test/srt/test_tracing.py b/test/srt/test_tracing.py
new file mode 100644
index 000000000..a3e6de6b5
--- /dev/null
+++ b/test/srt/test_tracing.py
@@ -0,0 +1,273 @@
+import multiprocessing as mp
+import os
+import subprocess
+import time
+import unittest
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import requests
+import zmq
+
+from sglang import Engine
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
+from sglang.srt.tracing.trace import *
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+@dataclass
+class Req:
+    rid: int
+    trace_context: Optional[Dict[str, Any]] = None
+
+
+class TestTrace(CustomTestCase):
+    def __launch_otel_jaeger(self):
+        cmd = [
+            "docker",
+            "compose",
+            "-f",
+            "../../examples/monitoring/tracing_compose.yaml",
+            "up",
+            "-d",
+        ]
+        proc = subprocess.run(cmd)
+
+        if proc.returncode != 0:
+            print("Failed to launch the OpenTelemetry Collector and Jaeger containers")
+            return False
+        return True
+
+    def __stop_otel_jaeger(self):
+        cmd = [
+            "docker",
+            "compose",
+            "-f",
+            "../../examples/monitoring/tracing_compose.yaml",
+            "down",
+        ]
+        proc = subprocess.run(cmd)
+
+        if proc.returncode != 0:
+            print("Failed to stop the OpenTelemetry Collector and Jaeger containers")
+            return False
+        return True
+
+    def __clear_trace_file(self):
+        try:
+            os.remove("/tmp/otel_trace.json")
+        except OSError:
+            pass
+
+    def test_trace_enable(self):
+        self.__clear_trace_file()
+        assert self.__launch_otel_jaeger()
+
+        process = popen_launch_server(
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            DEFAULT_URL_FOR_TEST,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--enable-trace", "--oltp-traces-endpoint", "0.0.0.0:4317"],
+        )
+
+        try:
+            # Make some requests to generate trace data
+            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
+            self.assertEqual(response.status_code, 200)
+
+            response = requests.post(
+                f"{DEFAULT_URL_FOR_TEST}/generate",
+                json={
+                    "text": "The capital of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 32,
+                    },
+                    "stream": True,
+                },
+                stream=True,
+            )
+            for _ in response.iter_lines(decode_unicode=False):
+                pass
+
+            # Sleep for a few seconds so the OpenTelemetry Collector can asynchronously export data to the file.
+            time.sleep(10)
+
+            # check trace file
+            assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
+            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+
+        finally:
+            kill_process_tree(process.pid)
+            assert self.__stop_otel_jaeger()
+
+    def test_trace_engine_enable(self):
+        self.__clear_trace_file()
+        assert self.__launch_otel_jaeger()
+
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = Engine(
+            model_path=model_path,
+            random_seed=42,
+            enable_trace=True,
+            oltp_traces_endpoint="localhost:4317",
+        )
+
+        try:
+            engine.generate(prompt, sampling_params)
+
+            # Sleep for a few seconds so the OpenTelemetry Collector can asynchronously export data to the file.
+            time.sleep(10)
+
+            # check trace file
+            assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
+            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+        finally:
+            engine.shutdown()
+            assert self.__stop_otel_jaeger()
+
+    def test_trace_engine_encode(self):
+        self.__clear_trace_file()
+        assert self.__launch_otel_jaeger()
+
+        prompt = "Today is a sunny day and I like"
+        model_path = "Qwen/Qwen2-7B"
+
+        engine = Engine(
+            model_path=model_path,
+            random_seed=42,
+            enable_trace=True,
+            oltp_traces_endpoint="localhost:4317",
+            is_embedding=True,
+        )
+
+        try:
+            engine.encode(prompt)
+
+            # Sleep for a few seconds so the OpenTelemetry Collector can asynchronously export data to the file.
+            time.sleep(10)
+
+            # check trace file
+            assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
+            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+        finally:
+            engine.shutdown()
+            assert self.__stop_otel_jaeger()
+
+    def test_slice_trace_simple(self):
+        self.__clear_trace_file()
+        assert self.__launch_otel_jaeger()
+        try:
+            process_tracing_init("0.0.0.0:4317", "test")
+            trace_set_thread_info("Test")
+            trace_req_start(0)
+            trace_slice_start("test slice", 0)
+            time.sleep(1)
+            trace_slice_end("test slice", 0)
+            trace_req_finish(0)
+
+            # Sleep for a few seconds so the OpenTelemetry Collector can asynchronously export data to the file.
+            time.sleep(10)
+            # check trace file
+            assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
+            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+        finally:
+            assert self.__stop_otel_jaeger()
+
+    def test_slice_trace_complex(self):
+        self.__clear_trace_file()
+        assert self.__launch_otel_jaeger()
+        try:
+            process_tracing_init("0.0.0.0:4317", "test")
+            trace_set_thread_info("Test")
+            trace_req_start(0)
+            trace_slice_start("", 0, anonymous=True)
+            time.sleep(1)
+            trace_slice_end("slice A", 0, auto_next_anon=True)
+            time.sleep(1)
+            trace_slice_end("slice B", 0, auto_next_anon=True)
+            time.sleep(1)
+            trace_slice_end("slice C", 0, thread_finish_flag=True)
+            trace_req_finish(0)
+
+            # Sleep for a few seconds so the OpenTelemetry Collector can asynchronously export data to the file.
+            time.sleep(10)
+            # check trace file
+            assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
+            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+        finally:
+            assert self.__stop_otel_jaeger()
+
+    def test_trace_context_propagate(self):
+        def __process_work():
+            process_tracing_init("0.0.0.0:4317", "test")
+            trace_set_thread_info("Sub Process")
+
+            context = zmq.Context(2)
+            recv_from_main = get_zmq_socket(
+                context, zmq.PULL, "ipc:///tmp/zmq_test.ipc", True
+            )
+
+            try:
+                req = recv_from_main.recv_pyobj()
+                trace_set_proc_propagate_context(req.rid, req.trace_context)
+                trace_slice_start("work", req.rid)
+                time.sleep(1)
+                trace_slice_end("work", req.rid, thread_finish_flag=True)
+            finally:
+                recv_from_main.close()
+                context.term()
+
+        self.__clear_trace_file()
+        assert self.__launch_otel_jaeger()
+
+        context = zmq.Context(2)
+        send_to_subproc = get_zmq_socket(
+            context, zmq.PUSH, "ipc:///tmp/zmq_test.ipc", False
+        )
+        try:
+            process_tracing_init("0.0.0.0:4317", "test")
+            trace_set_thread_info("Main Process")
+
+            subproc = mp.Process(target=__process_work)
+            subproc.start()
+
+            # Sleep for a second to ensure the subprocess has initialized.
+            time.sleep(1)
+
+            req = Req(rid=0)
+            trace_req_start(req.rid)
+            trace_slice_start("dispatch", req.rid)
+            time.sleep(1)
+            req.trace_context = trace_get_proc_propagate_context(req.rid)
+            send_to_subproc.send_pyobj(req)
+            trace_slice_end("dispatch", req.rid)
+
+            subproc.join()
+            trace_req_finish(req.rid)
+
+            # Sleep for a few seconds so the OpenTelemetry Collector can asynchronously export data to the file.
+            time.sleep(10)
+            # check trace file
+            assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
+            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+
+        finally:
+            send_to_subproc.close()
+            context.term()
+            assert self.__stop_otel_jaeger()
+
+
+if __name__ == "__main__":
+    unittest.main()