[Feature] Sglang Tracing: Fine-Grained Tracking for Request Latency - Part 1 (#9962)
Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
Signed-off-by: Huaixin Chang <changhuaixin@linux.alibaba.com>
Signed-off-by: Peng Wang <rocking@linux.alibaba.com>
118
docs/references/production_request_trace.md
Normal file
@@ -0,0 +1,118 @@
SGLang exports request trace data through the OpenTelemetry Collector. To enable tracing, pass `--enable-trace` when launching the server and set the OpenTelemetry Collector endpoint with `--oltp-traces-endpoint`.

You can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.

## Setup Guide
This section explains how to configure request tracing and export the trace data.
1. Install the required packages and tools
    * Install Docker and Docker Compose
    * Install the dependencies
    ```bash
    # enter the SGLang root directory
    pip install -e "python[tracing]"

    # or manually install the dependencies using pip
    pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc
    ```

2. Launch the OpenTelemetry Collector and Jaeger
    ```bash
    docker compose -f examples/monitoring/tracing_compose.yaml up -d
    ```

3. Start your SGLang server with tracing enabled
    ```bash
    python -m sglang.launch_server --enable-trace --oltp-traces-endpoint 0.0.0.0:4317 <other options>
    ```

    Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the OpenTelemetry Collector with `tracing_compose.yaml`, the default receiving port is 4317.

4. Send some requests

5. Observe whether trace data is being exported
    * Open port 16686 of Jaeger in a web browser to visualize the request traces.
    * The OpenTelemetry Collector also exports trace data in JSON format to `/tmp/otel_trace.json`. A follow-up patch will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI.
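
Until a conversion tool is available, the exported file can be inspected with a few lines of Python. The sketch below is not part of SGLang. It assumes the collector's file exporter writes one OTLP JSON object per line, with spans nested under `resourceSpans` / `scopeSpans` / `spans` and timestamps stored as `startTimeUnixNano` / `endTimeUnixNano`. The helper names `summarize_spans` and `print_summary` are made up for this example.

```python
import json
from collections import defaultdict
from typing import Dict, Iterable, List


def summarize_spans(lines: Iterable[str]) -> Dict[str, List[float]]:
    """Group span durations (in milliseconds) by span name.

    Assumption: each non-empty line is one OTLP JSON export batch.
    """
    durations: Dict[str, List[float]] = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        batch = json.loads(line)
        for resource_spans in batch.get("resourceSpans", []):
            for scope_spans in resource_spans.get("scopeSpans", []):
                for span in scope_spans.get("spans", []):
                    # OTLP JSON encodes 64-bit integers as strings.
                    start_ns = int(span["startTimeUnixNano"])
                    end_ns = int(span["endTimeUnixNano"])
                    durations[span["name"]].append((end_ns - start_ns) / 1e6)
    return dict(durations)


def print_summary(path: str = "/tmp/otel_trace.json") -> None:
    """Print the count and mean duration per span name for one trace file."""
    with open(path) as f:
        summary = summarize_spans(f)
    for name, values in sorted(summary.items()):
        mean_ms = sum(values) / len(values)
        print(f"{name}: count={len(values)} mean={mean_ms:.3f} ms")
```

Calling `print_summary()` after sending a few requests should list one row per slice name (for example the tokenizer and scheduler slices), which is a quick way to confirm that spans are reaching the collector.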
## How to Add Tracing for Slices You're Interested In

The tokenizer and scheduler main threads are already instrumented. To trace additional request execution segments, or to trace at a finer granularity, use the APIs from the tracing package as described below.

1. Initialization

    Every process involved in tracing should execute the following during its initialization phase:
    ```python
    process_tracing_init(oltp_traces_endpoint, server_name)
    ```
    The `oltp_traces_endpoint` is obtained from the arguments. You can set `server_name` freely, but it should remain consistent across all processes.

    Every thread involved in tracing should execute the following during its initialization phase:
    ```python
    trace_set_thread_info("thread label", tp_rank, dp_rank)
    ```
    The "thread label" can be regarded as the name of the thread, used to distinguish different threads in the visualization view.

2. Mark the beginning and end of a request
    ```python
    trace_req_start(rid, bootstrap_room)
    trace_req_finish(rid)
    ```
    These two APIs must be called within the same process, for example, in the tokenizer.

3. Add tracing for a slice

    - Add slice tracing normally:
        ```python
        trace_slice_start("slice A", rid)
        trace_slice_end("slice A", rid)
        ```

    - Use the "anonymous" flag to leave the slice name unspecified at the start of the slice, so that the name is determined by `trace_slice_end`.
    <br>Note: Anonymous slices must not be nested.
        ```python
        trace_slice_start("", rid, anonymous=True)
        trace_slice_end("slice A", rid)
        ```

    - In `trace_slice_end`, use `auto_next_anon` to automatically create the next anonymous slice, which reduces the number of instrumentation points needed.
        ```python
        trace_slice_start("", rid, anonymous=True)
        trace_slice_end("slice A", rid, auto_next_anon=True)
        trace_slice_end("slice B", rid, auto_next_anon=True)
        trace_slice_end("slice C", rid, auto_next_anon=True)
        trace_slice_end("slice D", rid)
        ```
    - The end of the last slice in a thread must be marked with `thread_finish_flag=True`; otherwise, the thread's span will not be properly generated.
        ```python
        trace_slice_end("slice D", rid, thread_finish_flag=True)
        ```

4. When the request execution flow transfers to another thread, the trace context needs to be explicitly propagated.
    - sender: Execute the following code before sending the request to another thread via ZMQ
        ```python
        trace_context = trace_get_proc_propagate_context(rid)
        req.trace_context = trace_context
        ```
    - receiver: Execute the following code after receiving the request via ZMQ
        ```python
        trace_set_proc_propagate_context(rid, req.trace_context)
        ```
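
The naming rules for anonymous slices in step 3 can be illustrated with a toy model. This is not the SGLang implementation and does not use OpenTelemetry. `ToySliceRecorder` is a hypothetical stand-in that only tracks one open, non-nested slice per thread, and it assumes the slice created by `auto_next_anon` starts at the timestamp where the previous slice ended.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToySliceRecorder:
    """Toy bookkeeping for one thread's slices (illustration only)."""

    # Currently open slice as (name, start_ts, anonymous), or None.
    open_slice: Optional[Tuple[str, int, bool]] = None
    # Finished slices as (name, start_ts, end_ts).
    finished: List[Tuple[str, int, int]] = field(default_factory=list)

    def slice_start(self, name: str, ts: int, anonymous: bool = False) -> None:
        self.open_slice = (name, ts, anonymous)

    def slice_end(self, name: str, ts: int, auto_next_anon: bool = False) -> None:
        if self.open_slice is None:
            raise RuntimeError("slice_end called without an open slice")
        open_name, start_ts, anonymous = self.open_slice
        # An anonymous slice takes its name from the end call.
        final_name = name if anonymous else open_name
        self.finished.append((final_name, start_ts, ts))
        # Assumed behaviour: the next anonymous slice starts where this one ended.
        self.open_slice = ("", ts, True) if auto_next_anon else None


recorder = ToySliceRecorder()
recorder.slice_start("", ts=0, anonymous=True)
recorder.slice_end("slice A", ts=10, auto_next_anon=True)
recorder.slice_end("slice B", ts=25, auto_next_anon=True)
recorder.slice_end("slice C", ts=30)
print(recorder.finished)
```

With one start call and three end calls, the toy records three back-to-back slices. That is the reduction in instrumentation points that `auto_next_anon` is meant to provide.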
## How to Extend the Tracing Framework to Support Complex Tracing Scenarios

The tracing package still has room for further development. If you wish to build more advanced features on top of it, you should first understand its existing design principles.

The core of the tracing framework is the design of the trace context. To aggregate scattered slices and track multiple requests concurrently, the framework uses a three-level trace context structure: `SglangTraceReqContext`, `SglangTraceThreadContext`, and `SglangTraceSliceContext`. Their relationship is as follows:
```
SglangTraceReqContext (req_id="req-123")
├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
│   └── SglangTraceSliceContext (name="prefill") # cur slice
│
└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
    └── SglangTraceSliceContext (name="prefill") # cur slice
```
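
The tree above can be mirrored with plain dataclasses. The sketch below is a simplified illustration, not the code in `trace.py`. The `Toy*` class names and the thread ids `1001` and `1002` are hypothetical, the field names follow the dataclasses added in this commit, and all OpenTelemetry span objects are omitted.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToySliceContext:
    slice_name: str


@dataclass
class ToyThreadContext:
    thread_label: str
    tp_rank: int
    # Stack of currently open slices; the last element is the innermost one.
    cur_slice_stack: List[ToySliceContext] = field(default_factory=list)


@dataclass
class ToyReqContext:
    rid: str
    # One thread context per thread that handles the request, keyed by thread id.
    threads_context: Dict[int, ToyThreadContext] = field(default_factory=dict)


req = ToyReqContext(rid="req-123")
for thread_id, tp_rank in ((1001, 0), (1002, 1)):
    thread_ctx = ToyThreadContext(thread_label="scheduler", tp_rank=tp_rank)
    thread_ctx.cur_slice_stack.append(ToySliceContext("prefill"))
    req.threads_context[thread_id] = thread_ctx

# Ending a slice pops it from its thread's stack, which releases that context.
finished = req.threads_context[1001].cur_slice_stack.pop()
print(finished.slice_name, len(req.threads_context[1001].cur_slice_stack))
```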
Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded within the `SglangTraceReqContext`. Within each thread, every slice currently being traced (slices may be nested) is represented by a `SglangTraceSliceContext`, which is stored in the `SglangTraceThreadContext`. When slice tracing, thread tracing, or request tracing ends, a span is generated and the corresponding context is released.

In addition to the above hierarchy, each slice also records its previous slice via `Span.add_link()`, which can be used to trace the execution flow.

When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.
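
To make the propagated payload concrete, here is a standalone sketch of a dictionary with the same shape that `SglangTracePropagateContext.to_dict()` produces in this commit: a `root_span` carrier plus a `prev_span` entry. It does not import OpenTelemetry. The helper names are hypothetical, and the `traceparent` key assumes the default W3C Trace Context propagator is in use. The round trip goes through `pickle` because the request object is sent with ZMQ's `send_pyobj`.

```python
import pickle
from typing import Dict, Optional, Tuple


def format_traceparent(trace_id: int, span_id: int, sampled: bool = True) -> str:
    """Encode IDs in the W3C Trace Context `traceparent` layout."""
    flags = 1 if sampled else 0
    return f"00-{trace_id:032x}-{span_id:016x}-{flags:02x}"


def parse_traceparent(value: str) -> Tuple[int, int, bool]:
    """Decode a version-00 `traceparent` string into (trace_id, span_id, sampled)."""
    version, trace_hex, span_hex, flags_hex = value.split("-")
    if version != "00" or len(trace_hex) != 32 or len(span_hex) != 16:
        raise ValueError(f"unsupported traceparent: {value!r}")
    return int(trace_hex, 16), int(span_hex, 16), bool(int(flags_hex, 16) & 1)


def build_carrier(root: Tuple[int, int], prev: Optional[Tuple[int, int]]) -> Dict:
    """Hypothetical helper: build a dict shaped like the propagate context."""
    carrier: Dict = {"root_span": {"traceparent": format_traceparent(*root)}}
    if prev is None:
        carrier["prev_span"] = "None"
    else:
        carrier["prev_span"] = {"trace_id": prev[0], "span_id": prev[1]}
    return carrier


# Sender side: the dict travels with the request, which is pickled for ZMQ.
sent = build_carrier(root=(0xABCDEF, 0x1234), prev=(0xABCDEF, 0x5678))
wire_bytes = pickle.dumps(sent)

# Receiver side: unpickle and recover the request span's IDs.
received = pickle.loads(wire_bytes)
trace_id, span_id, sampled = parse_traceparent(received["root_span"]["traceparent"])
print(hex(trace_id), hex(span_id), sampled, received["prev_span"])
```

Keeping the payload a plain dictionary of strings and integers means it can be attached to the request dataclass without making the receiving process depend on live span objects.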
38
examples/monitoring/opentelemetry.yaml
Normal file
@@ -0,0 +1,38 @@
receivers:
  otlp:
    protocols:
      grpc:
        endpoint: 0.0.0.0:4317
      http:
        endpoint: 0.0.0.0:4318
processors:
  batch:

exporters:
  otlp:
    endpoint: jaeger:4317
    tls:
      insecure: true
  file:
    path: /tmp/otel_trace.json

extensions:
  health_check:
  pprof:
  zpages:

service:
  extensions: [health_check, pprof, zpages]
  pipelines:
    traces:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp, file]
    metrics:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp]
    logs:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp]
21
examples/monitoring/tracing_compose.yaml
Normal file
@@ -0,0 +1,21 @@
services:
  otel-collector:
    image: docker.io/otel/opentelemetry-collector
    volumes:
      - ./opentelemetry.yaml:/etc/otelcol/config.yaml
      - /tmp:/tmp
    ports:
      - "4317:4317"   # OTLP gRPC
      - "4318:4318"   # OTLP HTTP
    depends_on:
      - jaeger
    restart: unless-stopped

  jaeger:
    image: jaegertracing/all-in-one
    container_name: jaeger
    ports:
      - "16686:16686"
    environment:
      - COLLECTOR_OTLP_ENABLED=true
    restart: unless-stopped
@@ -56,6 +56,13 @@ runtime_common = [
     "xgrammar==0.1.24",
 ]
 
+tracing = [
+    "opentelemetry-sdk",
+    "opentelemetry-api",
+    "opentelemetry-exporter-otlp",
+    "opentelemetry-exporter-otlp-proto-grpc",
+]
+
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.3.9.post2",
@@ -33,6 +33,8 @@ import zmq
 import zmq.asyncio
 from PIL.Image import Image
 
+from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -138,6 +140,12 @@ class Engine(EngineBase):
                 context, zmq.DEALER, self.port_args.rpc_ipc_name, True
             )
 
+        if server_args.enable_trace:
+            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+            if server_args.disaggregation_mode == "null":
+                thread_label = "Tokenizer"
+                trace_set_thread_info(thread_label)
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -31,6 +31,8 @@ from typing import Any, AsyncIterator, Callable, Dict, List, Optional
 
 import setproctitle
 
+from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -179,6 +181,13 @@ async def init_multi_tokenizer() -> ServerArgs:
             scheduler_info=scheduler_info,
         )
     )
+
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = f"MultiTokenizer-{tokenizer_manager.worker_id}"
+            trace_set_thread_info(thread_label)
+
     return server_args
 
 
@@ -1203,6 +1212,12 @@ def launch_server(
             server_args=server_args,
         )
 
+        if server_args.enable_trace:
+            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+            if server_args.disaggregation_mode == "null":
+                thread_label = "Tokenizer"
+                trace_set_thread_info(thread_label)
+
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
@@ -605,6 +605,9 @@ class TokenizedGenerateReqInput:
     # Image gen grpc migration
     return_bytes: bool = False
 
+    # tracing context
+    trace_context: Optional[Dict] = None
+
 
 @dataclass
 class BatchTokenizedGenerateReqInput:
@@ -654,6 +657,9 @@ class EmbeddingReqInput:
     # For background responses (OpenAI responses API)
     background: bool = False
 
+    # tracing context
+    trace_context: Optional[Dict] = None
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -149,6 +149,15 @@ from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_event,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -826,6 +835,10 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
@@ -847,6 +860,10 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
@@ -1110,6 +1127,12 @@ class Scheduler(
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
             )
+
+        for req in recv_reqs:
+            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
+                trace_set_proc_propagate_context(req.rid, req.trace_context)
+                trace_slice_start("", req.rid, anonymous=True)
+
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
@@ -1347,6 +1370,7 @@ class Scheduler(
         else:
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
 
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
@@ -1914,8 +1938,23 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "decode loop",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
+
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "prefill",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2600,6 +2639,12 @@ def run_scheduler_process(
     pipe_writer,
     balance_meta: Optional[DPBalanceMeta] = None,
 ):
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)
+
     if (numa_node := server_args.numa_node) is not None:
         numa_bind_to_node(numa_node[gpu_id])
 
@@ -82,6 +82,13 @@ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicat
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,
@@ -358,6 +365,24 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 # If it's a single value, add worker_id prefix
                 obj.rid = f"{self.worker_id}_{obj.rid}"
 
+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -574,6 +599,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             mm_inputs = None
 
         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -752,6 +778,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs
 
@@ -779,9 +806,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        trace_slice_start("dispatch", obj.rid)
+        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
         return state
 
     def _send_batch_request(
@@ -1429,6 +1459,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+                trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
                 del self.rid_to_state[rid]
 
                 # Mark ongoing LoRA request as finished.
@@ -215,6 +215,8 @@ class ServerArgs:
     enable_request_time_stats_logging: bool = False
     kv_events_config: Optional[str] = None
     gc_warning_threshold_secs: float = 0.0
+    enable_trace: bool = False
+    oltp_traces_endpoint: str = "localhost:4317"
 
     # API related
     api_key: Optional[str] = None
@@ -1390,6 +1392,17 @@ class ServerArgs:
             default=None,
             help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
         )
+        parser.add_argument(
+            "--enable-trace",
+            action="store_true",
+            help="Enable opentelemetry trace",
+        )
+        parser.add_argument(
+            "--oltp-traces-endpoint",
+            type=str,
+            default="localhost:4317",
+            help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
+        )
 
         # API related
         parser.add_argument(
552
python/sglang/srt/tracing/trace.py
Normal file
@@ -0,0 +1,552 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""package for sglang requests tracing"""

from __future__ import annotations

import ctypes
import logging
import os
import random
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)
opentelemetry_imported = False
tracing_enabled = False

try:
    from opentelemetry import context, propagate, trace
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.resources import SERVICE_NAME, Resource
    from opentelemetry.sdk.trace import TracerProvider, id_generator
    from opentelemetry.sdk.trace.export import BatchSpanProcessor

    opentelemetry_imported = True
except ImportError:

    class id_generator:
        class IdGenerator:
            pass

    logger.info("opentelemetry package is not installed, tracing disabled")


@dataclass
class SglangTraceThreadInfo:
    host_id: str
    pid: int
    thread_label: str
    tp_rank: int
    dp_rank: int
    tracer: trace.Tracer


@dataclass
class SglangTraceSliceContext:
    slice_name: str
    span: Optional[trace.span.Span] = None
    # When True, defers slice_name assignment until trace_slice_end()
    anonymous: bool = False


@dataclass
class SglangTraceThreadContext:
    thread_info: SglangTraceThreadInfo
    cur_slice_stack: List[SglangTraceSliceContext]
    thread_span: Optional[trace.span.Span] = None
    # Record the most recently completed span as the previous span for the next span to be created.
    last_span_context: Optional[trace.span.SpanContext] = None


@dataclass
class SglangTraceReqContext:
    rid: str
    start_time_ns: int
    threads_context: Dict[int, SglangTraceThreadContext]
    bootstrap_room: Optional[int] = None

    # Indicates whether this instance is a replica from the main process.
    # When True, root_span is None and only root_span_context is preserved.
    is_copy: bool = False
    root_span: Optional[trace.span.Span] = None
    root_span_context: Optional[context.Context] = None


@dataclass
class SglangTracePropagateContext:
    root_span_context: context.Context
    prev_span_context: Optional[trace.span.SpanContext]

    def to_dict(self):
        carrier: dict[str, str] = {}
        context.attach(self.root_span_context)
        propagate.inject(carrier)

        if self.prev_span_context:
            return {
                "root_span": carrier,
                "prev_span": {
                    "span_id": self.prev_span_context.span_id,
                    "trace_id": self.prev_span_context.trace_id,
                },
            }
        else:
            return {"root_span": carrier, "prev_span": "None"}

    @classmethod
    def instance_from_dict(cls, d):
        if "root_span" not in d or "prev_span" not in d:
            return None

        carrier = d["root_span"]
        root_span_context = propagate.extract(carrier)

        if d["prev_span"] == "None":
            prev_span_context = None
        else:
            prev_span_context = trace.span.SpanContext(
                trace_id=d["prev_span"]["trace_id"],
                span_id=d["prev_span"]["span_id"],
                is_remote=True,
            )

        return cls(root_span_context, prev_span_context)


class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
    """
    The default IdGenerator may produce duplicate trace IDs across multiple TP scheduler processes,
    hence a custom IdGenerator is implemented.
    """

    def __init__(self):
        super().__init__()
        self.local_random = random.Random()
        self.local_random.seed(time.time())

    def generate_trace_id(self) -> int:
        return self.local_random.getrandbits(64)

    def generate_span_id(self) -> int:
        return self.local_random.getrandbits(64)


# global variables
threads_info: Dict[int, SglangTraceThreadInfo] = {}
reqs_context: Dict[str, SglangTraceReqContext] = {}

__get_cur_time_ns = lambda: int(time.time() * 1e9)


|
def __get_host_id() -> str:
|
||||||
|
"""
|
||||||
|
In distributed tracing systems, obtain a unique node identifier
|
||||||
|
and inject it into all subsequently generated spans
|
||||||
|
to prevent PID conflicts between threads on different nodes.
|
||||||
|
"""
|
||||||
|
if os.path.exists("/etc/machine-id"):
|
||||||
|
try:
|
||||||
|
with open("/etc/machine-id", "r") as f:
|
||||||
|
return f.read().strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mac = uuid.getnode()
|
||||||
|
if mac != 0:
|
||||||
|
return uuid.UUID(int=mac).hex
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# Should be called by each tracked process.
|
||||||
|
def process_tracing_init(otlp_endpoint, server_name):
|
||||||
|
global tracing_enabled
|
||||||
|
global __get_cur_time_ns
|
||||||
|
if not opentelemetry_imported:
|
||||||
|
tracing_enabled = False
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
resource = Resource.create(
|
||||||
|
attributes={
|
||||||
|
SERVICE_NAME: server_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tracer_provider = TracerProvider(
|
||||||
|
resource=resource, id_generator=SglangTraceCustomIdGenerator()
|
||||||
|
)
|
||||||
|
|
||||||
|
processor = BatchSpanProcessor(
|
||||||
|
OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
|
||||||
|
)
|
||||||
|
tracer_provider.add_span_processor(processor)
|
||||||
|
trace.set_tracer_provider(tracer_provider)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize OpenTelemetry: {e}")
|
||||||
|
logger.warning("Please set the correct OTLP endpoint.")
|
||||||
|
tracing_enabled = False
|
||||||
|
return
|
||||||
|
|
||||||
|
if hasattr(time, "time_ns"):
|
||||||
|
__get_cur_time_ns = lambda: int(time.time_ns())
|
||||||
|
|
||||||
|
tracing_enabled = True
|
||||||
|
|
||||||
|
|
||||||
|
# Should be called by each tracked thread.
|
||||||
|
def trace_set_thread_info(
|
||||||
|
thread_label: str, tp_rank: Optional[int] = None, dp_rank: Optional[int] = None
|
||||||
|
):
|
||||||
|
if not tracing_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
if pid in threads_info:
|
||||||
|
return
|
||||||
|
|
||||||
|
threads_info[pid] = SglangTraceThreadInfo(
|
||||||
|
host_id=__get_host_id(),
|
||||||
|
pid=pid,
|
||||||
|
thread_label=thread_label,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
tracer=trace.get_tracer("sglang server"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def __create_thread_context(pid, req_span_context, ts: Optional[int] = None):
|
||||||
|
if pid not in threads_info:
|
||||||
|
trace_set_thread_info("unknown")
|
||||||
|
|
||||||
|
thread_info = threads_info[pid]
|
||||||
|
thread_context = SglangTraceThreadContext(
|
||||||
|
thread_info=thread_info,
|
||||||
|
cur_slice_stack=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_name = f"{thread_info.thread_label}"
|
||||||
|
if thread_info.tp_rank is not None:
|
||||||
|
thread_name += f" [TP {thread_info.tp_rank}] "
|
||||||
|
thread_name += f"(host:{thread_info.host_id[:8]} | pid:{pid})"
|
||||||
|
ts = ts or __get_cur_time_ns()
|
||||||
|
thread_context.thread_span = thread_context.thread_info.tracer.start_span(
|
||||||
|
name=thread_name,
|
||||||
|
start_time=ts,
|
||||||
|
context=req_span_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
if thread_info.tp_rank is not None:
|
||||||
|
thread_context.thread_span.set_attributes({"tp_rank": thread_info.tp_rank})
|
||||||
|
|
||||||
|
thread_context.thread_span.set_attributes(
|
||||||
|
{
|
||||||
|
"host_id": thread_info.host_id,
|
||||||
|
"pid": thread_info.pid,
|
||||||
|
"thread_label": thread_info.thread_label,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return thread_context
|
||||||
|
|
||||||
|
|
||||||
|
def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
|
||||||
|
if not tracing_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rid = str(rid)
|
||||||
|
if rid not in reqs_context or not reqs_context[rid].root_span_context:
|
||||||
|
return None
|
||||||
|
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
prev_span_context = None
|
||||||
|
thread_context = reqs_context[rid].threads_context[pid]
|
||||||
|
if thread_context.cur_slice_stack:
|
||||||
|
cur_slice_info = thread_context.cur_slice_stack[0]
|
||||||
|
prev_span_context = cur_slice_info.span.get_span_context()
|
||||||
|
elif thread_context.last_span_context:
|
||||||
|
prev_span_context = thread_context.last_span_context
|
||||||
|
|
||||||
|
trace_context = SglangTracePropagateContext(
|
||||||
|
reqs_context[rid].root_span_context, prev_span_context
|
||||||
|
)
|
||||||
|
return trace_context.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]]):
|
||||||
|
if not tracing_enabled:
|
||||||
|
return
|
||||||
|
if not trace_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
trace_context = SglangTracePropagateContext.instance_from_dict(trace_context)
|
||||||
|
if not trace_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
rid = str(rid)
|
||||||
|
# Create a copy of the request context
|
||||||
|
if rid not in reqs_context:
|
||||||
|
reqs_context[rid] = SglangTraceReqContext(
|
||||||
|
rid=rid,
|
||||||
|
start_time_ns=__get_cur_time_ns(),
|
||||||
|
threads_context={},
|
||||||
|
root_span_context=trace_context.root_span_context,
|
||||||
|
is_copy=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
|
||||||
|
if pid in reqs_context[rid].threads_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create new thread context.
|
||||||
|
reqs_context[rid].threads_context[pid] = __create_thread_context(
|
||||||
|
pid,
|
||||||
|
trace_context.root_span_context,
|
||||||
|
reqs_context[rid].start_time_ns,
|
||||||
|
)
|
||||||
|
|
||||||
|
reqs_context[rid].threads_context[
|
||||||
|
pid
|
||||||
|
].last_span_context = trace_context.prev_span_context
|
||||||
|
|
||||||
|
|
||||||
|
def trace_req_start(
|
||||||
|
rid: str,
|
||||||
|
bootstrap_room: Optional[int] = None,
|
||||||
|
ts: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if not tracing_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
rid = str(rid)
|
||||||
|
|
||||||
|
ts = ts or __get_cur_time_ns()
|
||||||
|
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
if pid not in threads_info:
|
||||||
|
return
|
||||||
|
|
||||||
|
# create req context and root span
|
||||||
|
reqs_context[rid] = SglangTraceReqContext(
|
||||||
|
rid=rid,
|
||||||
|
start_time_ns=ts,
|
||||||
|
threads_context={},
|
||||||
|
bootstrap_room=bootstrap_room,
|
||||||
|
is_copy=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Drop the worker_id added by MultiTokenizer
|
||||||
|
orig_rid = rid.split("_")[-1]
|
||||||
|
tracer = threads_info[pid].tracer
|
||||||
|
root_span = tracer.start_span(
|
||||||
|
name=f"Req {orig_rid[:8]}",
|
||||||
|
start_time=ts,
|
||||||
|
)
|
||||||
|
|
||||||
|
root_span.set_attributes(
|
||||||
|
{
|
||||||
|
"rid": rid,
|
||||||
|
"bootstrap_room": bootstrap_room if bootstrap_room else "None",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
reqs_context[rid].root_span = root_span
|
||||||
|
reqs_context[rid].root_span_context = trace.set_span_in_context(root_span)
|
||||||
|
|
||||||
|
# create thread context and thread span
|
||||||
|
reqs_context[rid].threads_context[pid] = __create_thread_context(
|
||||||
|
pid,
|
||||||
|
reqs_context[rid].root_span_context,
|
||||||
|
ts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def trace_req_finish(
|
||||||
|
rid: str, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None
|
||||||
|
):
|
||||||
|
if not tracing_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
rid = str(rid)
|
||||||
|
if rid not in reqs_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
req_context = reqs_context[rid]
|
||||||
|
ts = ts or __get_cur_time_ns()
|
||||||
|
|
||||||
|
# End all unclosed thread spans.
|
||||||
|
for thread_context in req_context.threads_context.values():
|
||||||
|
thread_context.thread_span.end(end_time=ts)
|
||||||
|
|
||||||
|
if attrs:
|
||||||
|
req_context.root_span.set_attributes(attrs)
|
||||||
|
|
||||||
|
req_context.root_span.end(end_time=ts)
|
||||||
|
|
||||||
|
del reqs_context[rid]
|
||||||
|
|
||||||
|
|
||||||
|
def trace_slice_start(
|
||||||
|
name: str,
|
||||||
|
rid: str,
|
||||||
|
ts: Optional[int] = None,
|
||||||
|
anonymous: bool = False,
|
||||||
|
):
|
||||||
|
|
||||||
|
rid = str(rid)
|
||||||
|
if not tracing_enabled or rid not in reqs_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
if pid not in reqs_context[rid].threads_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_context = reqs_context[rid].threads_context[pid]
|
||||||
|
|
||||||
|
ts = ts or __get_cur_time_ns()
|
||||||
|
|
||||||
|
slice_info = SglangTraceSliceContext(
|
||||||
|
slice_name=name,
|
||||||
|
anonymous=anonymous,
|
||||||
|
)
|
||||||
|
|
||||||
|
# find prev slice
|
||||||
|
prev_span_context = None
|
||||||
|
if not thread_context.cur_slice_stack:
|
||||||
|
if thread_context.last_span_context:
|
||||||
|
prev_span_context = thread_context.last_span_context
|
||||||
|
|
||||||
|
parent_span = thread_context.thread_span
|
||||||
|
if thread_context.cur_slice_stack:
|
||||||
|
parent_span = thread_context.cur_slice_stack[-1].span
|
||||||
|
|
||||||
|
parent_span_context = trace.set_span_in_context(parent_span)
|
||||||
|
span = thread_context.thread_info.tracer.start_span(
|
||||||
|
name=slice_info.slice_name,
|
||||||
|
start_time=ts,
|
||||||
|
context=parent_span_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prev_span_context:
|
||||||
|
span.add_link(prev_span_context)
|
||||||
|
|
||||||
|
slice_info.span = span
|
||||||
|
|
||||||
|
thread_context.cur_slice_stack.append(slice_info)
|
||||||
|
|
||||||
|
|
||||||
|
def trace_slice_end(
|
||||||
|
name: str,
|
||||||
|
rid: str,
|
||||||
|
ts: Optional[int] = None,
|
||||||
|
attrs: Optional[Dict[str, Any]] = None,
|
||||||
|
auto_next_anon: bool = False,
|
||||||
|
thread_finish_flag: bool = False,
|
||||||
|
):
|
||||||
|
rid = str(rid)
|
||||||
|
if not tracing_enabled or rid not in reqs_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
if pid not in reqs_context[rid].threads_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_context = reqs_context[rid].threads_context[pid]
|
||||||
|
|
||||||
|
if not thread_context.cur_slice_stack:
|
||||||
|
logger.warning(f"No matching SLICE_START event found for slice '{name}'.")
|
||||||
|
return
|
||||||
|
|
||||||
|
ts = ts or __get_cur_time_ns()
|
||||||
|
slice_info = thread_context.cur_slice_stack[-1]
|
||||||
|
span = slice_info.span
|
||||||
|
|
||||||
|
if slice_info.anonymous:
|
||||||
|
span.update_name(name)
|
||||||
|
else:
|
||||||
|
span = slice_info.span
|
||||||
|
if slice_info.slice_name != name:
|
||||||
|
span.set_status(trace.Status(trace.StatusCode.ERROR))
|
||||||
|
logger.warning(f"Slice name mismatch: {name} != {slice_info.slice_name}")
|
||||||
|
|
||||||
|
if attrs:
|
||||||
|
span.set_attributes(attrs)
|
||||||
|
|
||||||
|
span.end(end_time=ts)
|
||||||
|
|
||||||
|
thread_context.cur_slice_stack.pop()
|
||||||
|
if len(thread_context.cur_slice_stack) == 0:
|
||||||
|
thread_context.last_span_context = span.get_span_context()
|
||||||
|
|
||||||
|
# If this is the last slice in the thread,
|
||||||
|
# release the thread context and check whether to release the request context.
|
||||||
|
if thread_finish_flag:
|
||||||
|
thread_context.thread_span.end(end_time=ts)
|
||||||
|
del reqs_context[rid].threads_context[pid]
|
||||||
|
if reqs_context[rid].is_copy and not reqs_context[rid].threads_context:
|
||||||
|
del reqs_context[rid]
|
||||||
|
return
|
||||||
|
|
||||||
|
if auto_next_anon:
|
||||||
|
trace_slice_start("", rid, ts, True)
|
||||||
|
|
||||||
|
|
||||||
|
# alias
|
||||||
|
trace_slice = trace_slice_end
|
||||||
|
|
||||||
|
|
||||||
|
# Add event to the current slice on the same thread with the same rid.
|
||||||
|
def trace_event(name: str, rid: str, ts: Optional[int] = None):
|
||||||
|
if not tracing_enabled or str(rid) not in reqs_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
rid = str(rid)
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
if pid not in reqs_context[rid].threads_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_context = reqs_context[rid].threads_context[pid]
|
||||||
|
|
||||||
|
if not thread_context.cur_slice_stack:
|
||||||
|
logger.warning("No slice is currently being traced.")
|
||||||
|
return
|
||||||
|
|
||||||
|
ts = ts or __get_cur_time_ns()
|
||||||
|
|
||||||
|
slice_info = thread_context.cur_slice_stack[-1]
|
||||||
|
slice_info.span.add_event(name=name, timestamp=ts)
|
||||||
|
|
||||||
|
|
||||||
|
# Add attrs to the current slice on the same thread with the same rid.
|
||||||
|
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
|
||||||
|
if not tracing_enabled or str(rid) not in reqs_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
rid = str(rid)
|
||||||
|
pid = threading.get_native_id()
|
||||||
|
if pid not in reqs_context[rid].threads_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_context = reqs_context[rid].threads_context[pid]
|
||||||
|
|
||||||
|
if not thread_context.cur_slice_stack:
|
||||||
|
logger.warning("No slice is currently being traced.")
|
||||||
|
return
|
||||||
|
|
||||||
|
slice_info = thread_context.cur_slice_stack[-1]
|
||||||
|
slice_info.span.set_attributes(attrs)
|
||||||
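The slice functions above keep a per-thread stack of open slices: an anonymous slice takes its name from the matching end call, and `auto_next_anon` opens the next anonymous slice at the timestamp where the previous one ended. The following is a minimal standalone sketch of that bookkeeping only. It does not use OpenTelemetry, and `SliceStackModel`, `_Slice`, and their fields are hypothetical names introduced for illustration, not part of the commit.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class _Slice:
    # Hypothetical record for one open slice.
    name: str
    start: int
    anonymous: bool = False


@dataclass
class SliceStackModel:
    """Toy stand-in for one thread's slice stack; records (name, start, end)."""

    stack: List[_Slice] = field(default_factory=list)
    finished: List[Tuple[str, int, int]] = field(default_factory=list)

    def start(self, name: str, ts: int, anonymous: bool = False) -> None:
        self.stack.append(_Slice(name, ts, anonymous))

    def end(self, name: str, ts: int, auto_next_anon: bool = False) -> None:
        if not self.stack:
            return  # nothing is open, so there is nothing to close
        cur = self.stack.pop()
        # An anonymous slice is named by the end call that closes it.
        final_name = name if cur.anonymous else cur.name
        self.finished.append((final_name, cur.start, ts))
        if auto_next_anon:
            # The next anonymous slice starts exactly where this one ended.
            self.start("", ts, anonymous=True)


model = SliceStackModel()
model.start("", ts=0, anonymous=True)
model.end("slice A", ts=10, auto_next_anon=True)
model.end("slice B", ts=25, auto_next_anon=True)
model.end("slice C", ts=40)
print(model.finished)
# [('slice A', 0, 10), ('slice B', 10, 25), ('slice C', 25, 40)]
```

Because each anonymous slice begins at the previous end timestamp, the recorded intervals are contiguous, which is the pattern `test_slice_trace_complex` exercises with "slice A", "slice B", and "slice C".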
273
test/srt/test_tracing.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from sglang import Engine
|
||||||
|
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||||
|
from sglang.srt.tracing.trace import *
|
||||||
|
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Req:
|
||||||
|
rid: int
|
||||||
|
trace_context: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrace(CustomTestCase):
|
||||||
|
def __launch_otel_jaeger(self):
|
||||||
|
cmd = [
|
||||||
|
"docker",
|
||||||
|
"compose",
|
||||||
|
"-f",
|
||||||
|
"../../examples/monitoring/tracing_compose.yaml",
|
||||||
|
"up",
|
||||||
|
"-d",
|
||||||
|
]
|
||||||
|
proc = subprocess.run(cmd)
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
print("Failed to launch the OpenTelemetry collector and Jaeger containers")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __stop_otel_jaeger(self):
|
||||||
|
cmd = [
|
||||||
|
"docker",
|
||||||
|
"compose",
|
||||||
|
"-f",
|
||||||
|
"../../examples/monitoring/tracing_compose.yaml",
|
||||||
|
"down",
|
||||||
|
]
|
||||||
|
proc = subprocess.run(cmd)
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
print("Failed to stop the OpenTelemetry collector and Jaeger containers")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __clear_trace_file(self):
|
||||||
|
try:
|
||||||
|
os.remove("/tmp/otel_trace.json")
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_trace_enable(self):
|
||||||
|
self.__clear_trace_file()
|
||||||
|
assert self.__launch_otel_jaeger()
|
||||||
|
|
||||||
|
process = popen_launch_server(
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--enable-trace", "--oltp-traces-endpoint", "0.0.0.0:4317"],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Make some requests to generate trace data
|
||||||
|
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{DEFAULT_URL_FOR_TEST}/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
},
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for _ in response.iter_lines(decode_unicode=False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
# check trace file
|
||||||
|
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
|
||||||
|
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
assert self.__stop_otel_jaeger()
|
||||||
|
|
||||||
|
def test_trace_engine_enable(self):
|
||||||
|
self.__clear_trace_file()
|
||||||
|
assert self.__launch_otel_jaeger()
|
||||||
|
|
||||||
|
prompt = "Today is a sunny day and I like"
|
||||||
|
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||||
|
|
||||||
|
engine = Engine(
|
||||||
|
model_path=model_path,
|
||||||
|
random_seed=42,
|
||||||
|
enable_trace=True,
|
||||||
|
oltp_traces_endpoint="localhost:4317",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine.generate(prompt, sampling_params)
|
||||||
|
|
||||||
|
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
# check trace file
|
||||||
|
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
|
||||||
|
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
|
||||||
|
finally:
|
||||||
|
engine.shutdown()
|
||||||
|
assert self.__stop_otel_jaeger()
|
||||||
|
|
||||||
|
def test_trace_engine_encode(self):
|
||||||
|
self.__clear_trace_file()
|
||||||
|
assert self.__launch_otel_jaeger()
|
||||||
|
|
||||||
|
prompt = "Today is a sunny day and I like"
|
||||||
|
model_path = "Qwen/Qwen2-7B"
|
||||||
|
|
||||||
|
engine = Engine(
|
||||||
|
model_path=model_path,
|
||||||
|
random_seed=42,
|
||||||
|
enable_trace=True,
|
||||||
|
oltp_traces_endpoint="localhost:4317",
|
||||||
|
is_embedding=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine.encode(prompt)
|
||||||
|
|
||||||
|
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
# check trace file
|
||||||
|
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
|
||||||
|
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
|
||||||
|
finally:
|
||||||
|
engine.shutdown()
|
||||||
|
assert self.__stop_otel_jaeger()
|
||||||
|
|
||||||
|
def test_slice_trace_simple(self):
|
||||||
|
self.__clear_trace_file()
|
||||||
|
assert self.__launch_otel_jaeger()
|
||||||
|
try:
|
||||||
|
process_tracing_init("0.0.0.0:4317", "test")
|
||||||
|
trace_set_thread_info("Test")
|
||||||
|
trace_req_start(0)
|
||||||
|
trace_slice_start("test slice", 0)
|
||||||
|
time.sleep(1)
|
||||||
|
trace_slice_end("test slice", 0)
|
||||||
|
trace_req_finish(0)
|
||||||
|
|
||||||
|
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
|
||||||
|
time.sleep(10)
|
||||||
|
# check trace file
|
||||||
|
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
|
||||||
|
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
|
||||||
|
finally:
|
||||||
|
assert self.__stop_otel_jaeger()
|
||||||
|
|
||||||
|
def test_slice_trace_complex(self):
|
||||||
|
self.__clear_trace_file()
|
||||||
|
assert self.__launch_otel_jaeger()
|
||||||
|
try:
|
||||||
|
process_tracing_init("0.0.0.0:4317", "test")
|
||||||
|
trace_set_thread_info("Test")
|
||||||
|
trace_req_start(0)
|
||||||
|
trace_slice_start("", 0, anonymous=True)
|
||||||
|
time.sleep(1)
|
||||||
|
trace_slice_end("slice A", 0, auto_next_anon=True)
|
||||||
|
time.sleep(1)
|
||||||
|
trace_slice_end("slice B", 0, auto_next_anon=True)
|
||||||
|
time.sleep(1)
|
||||||
|
trace_slice_end("slice C", 0, thread_finish_flag=True)
|
||||||
|
trace_req_finish(0)
|
||||||
|
|
||||||
|
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
|
||||||
|
time.sleep(10)
|
||||||
|
# check trace file
|
||||||
|
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
|
||||||
|
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
|
||||||
|
finally:
|
||||||
|
assert self.__stop_otel_jaeger()
|
||||||
|
|
||||||
|
def test_trace_context_propagate(self):
|
||||||
|
def __process_work():
|
||||||
|
process_tracing_init("0.0.0.0:4317", "test")
|
||||||
|
trace_set_thread_info("Sub Process")
|
||||||
|
|
||||||
|
context = zmq.Context(2)
|
||||||
|
recv_from_main = get_zmq_socket(
|
||||||
|
context, zmq.PULL, "ipc:///tmp/zmq_test.ipc", True
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
req = recv_from_main.recv_pyobj()
|
||||||
|
trace_set_proc_propagate_context(req.rid, req.trace_context)
|
||||||
|
trace_slice_start("work", req.rid)
|
||||||
|
time.sleep(1)
|
||||||
|
trace_slice_end("work", req.rid, thread_finish_flag=True)
|
||||||
|
finally:
|
||||||
|
recv_from_main.close()
|
||||||
|
context.term()
|
||||||
|
|
||||||
|
self.__clear_trace_file()
|
||||||
|
assert self.__launch_otel_jaeger()
|
||||||
|
|
||||||
|
context = zmq.Context(2)
|
||||||
|
send_to_subproc = get_zmq_socket(
|
||||||
|
context, zmq.PUSH, "ipc:///tmp/zmq_test.ipc", False
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
process_tracing_init("0.0.0.0:4317", "test")
|
||||||
|
trace_set_thread_info("Main Process")
|
||||||
|
|
||||||
|
subproc = mp.Process(target=__process_work)
|
||||||
|
subproc.start()
|
||||||
|
|
||||||
|
# Sleep for a second so the subprocess can finish initializing.
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
req = Req(rid=0)
|
||||||
|
trace_req_start(req.rid)
|
||||||
|
trace_slice_start("dispatch", req.rid)
|
||||||
|
time.sleep(1)
|
||||||
|
req.trace_context = trace_get_proc_propagate_context(req.rid)
|
||||||
|
send_to_subproc.send_pyobj(req)
|
||||||
|
trace_slice_end("dispatch", req.rid)
|
||||||
|
|
||||||
|
subproc.join()
|
||||||
|
trace_req_finish(req.rid)
|
||||||
|
|
||||||
|
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
|
||||||
|
time.sleep(10)
|
||||||
|
# check trace file
|
||||||
|
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
|
||||||
|
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
send_to_subproc.close()
|
||||||
|
context.term()
|
||||||
|
assert self.__stop_otel_jaeger()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
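Each test above sleeps a fixed 10 seconds before checking `/tmp/otel_trace.json`. One possible alternative, shown here only as a sketch, is to poll for a non-empty file up to a deadline so the check can return as soon as the collector has written data. `wait_for_nonempty_file` and its parameters are hypothetical and not part of the commit; the timeout values are illustrative.

```python
import os
import tempfile
import time


def wait_for_nonempty_file(
    path: str, timeout_s: float = 10.0, interval_s: float = 0.1
) -> bool:
    """Return True once `path` exists and is non-empty, False on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        if os.path.isfile(path) and os.path.getsize(path) > 0:
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Demonstration with a temporary file standing in for the trace file.
with tempfile.TemporaryDirectory() as tmp:
    trace_path = os.path.join(tmp, "otel_trace.json")
    missing = wait_for_nonempty_file(trace_path, timeout_s=0.3)
    with open(trace_path, "w") as f:
        f.write("{}")
    found = wait_for_nonempty_file(trace_path, timeout_s=0.3)

print(missing, found)
```

In a test, the two file assertions could then be replaced by a single assertion on the helper's return value, with the timeout chosen to match how long the collector is expected to take to flush.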