[Metrics] Add KV events publishing (#6098)
This commit is contained in:
@@ -25,6 +25,7 @@ runtime_common = [
|
|||||||
"interegular",
|
"interegular",
|
||||||
"llguidance>=0.7.11,<0.8.0",
|
"llguidance>=0.7.11,<0.8.0",
|
||||||
"modelscope",
|
"modelscope",
|
||||||
|
"msgspec",
|
||||||
"ninja",
|
"ninja",
|
||||||
"orjson",
|
"orjson",
|
||||||
"packaging",
|
"packaging",
|
||||||
|
|||||||
357
python/sglang/srt/disaggregation/kv_events.py
Normal file
357
python/sglang/srt/disaggregation/kv_events.py
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2025 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
KV caching events
|
||||||
|
"""
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import deque
|
||||||
|
from itertools import count
|
||||||
|
from queue import Queue
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
import zmq
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EventBatch(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
):
    """A timestamped batch of events serialized together as one message.

    Encoded as a msgpack array (``array_like=True``) with defaults omitted;
    ``gc=False`` keeps instances out of the cyclic garbage collector.
    """

    # Wall-clock creation time of the batch (seconds, as from time.time()).
    ts: float
    # Events carried by this batch; subclasses narrow the element type.
    events: list[Any]
|
||||||
|
|
||||||
|
|
||||||
|
class KVCacheEvent(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
    # tag=True embeds the subclass name in the payload so decoders can
    # distinguish BlockStored / BlockRemoved / AllBlocksCleared in a Union.
    tag=True,
):
    """Base class for all KV cache-related events"""
|
||||||
|
|
||||||
|
|
||||||
|
class BlockStored(KVCacheEvent):
    """Event emitted when one or more KV cache blocks are stored."""

    # Hashes identifying the stored block(s).
    block_hashes: list[int]
    # Hash of the parent block, or None when there is no parent.
    parent_block_hash: Optional[int]
    # Token ids covered by the stored block(s).
    token_ids: list[int]
    # Number of tokens per block.
    block_size: int
    # LoRA adapter id, if any (None when no adapter is involved).
    lora_id: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockRemoved(KVCacheEvent):
    """Event emitted when KV cache blocks are evicted/removed."""

    # Hashes identifying the removed block(s).
    block_hashes: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class AllBlocksCleared(KVCacheEvent):
    """Event emitted when the entire KV cache has been cleared."""

    pass
|
||||||
|
|
||||||
|
|
||||||
|
class KVEventBatch(EventBatch):
    """EventBatch specialized to the concrete KV cache event types."""

    # Narrows EventBatch.events to the tagged union of KV cache events.
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
|
||||||
|
|
||||||
|
|
||||||
|
class EventPublisher(ABC):
    """Lightweight publisher for EventBatch batches."""

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit events in order.

        Implementations should guarantee at-least-once delivery and
        monotonic ordering (e.g., via sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the publisher."""
|
||||||
|
|
||||||
|
|
||||||
|
class NullEventPublisher(EventPublisher):
    """Publisher that silently discards every batch (default when disabled)."""

    def publish(self, events) -> None:
        """Drop the batch without doing any work."""

    def shutdown(self) -> None:
        """Nothing to clean up for the null publisher."""
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """

    # How long shutdown() waits for the queue to drain / thread to join.
    SHUTDOWN_TIMEOUT: float = 1.0
    # Sequence marker sent to replay clients to signal end-of-stream.
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._endpoint = endpoint
        self._replay_endpoint = replay_endpoint
        self._hwm = hwm
        self._socket_setup()

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()

        atexit.register(self.shutdown)

    def publish(self, events: EventBatch) -> None:
        """Enqueue a batch for the background thread to send.

        Raises:
            RuntimeError: if the publisher has already been shut down.
        """
        if not self._running:
            raise RuntimeError("Publisher is closed")
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources."""
        self._running = False
        # Wake the publisher thread with a sentinel. The queue may already be
        # full; that is not fatal (the thread also exits once `_running` is
        # False and the queue drains), so don't let queue.Full propagate out
        # of shutdown (which is registered with atexit).
        try:
            self._event_queue.put_nowait(None)
        except queue.Full:
            logger.warning("Event queue full during shutdown; sentinel dropped")

        # Give the background thread a bounded amount of time to drain.
        start = time.time()
        pending_items = True
        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Clean up ZMQ resources
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
            if self._replay is not None:
                self._replay.close(linger=0)
        finally:
            pass  # Do not terminate context; other sockets may use it

    def _socket_setup(self) -> None:
        """Initialize sockets
        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith("ipc://")
                or self._endpoint.startswith("inproc://")
            ):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request → many replies (streamed events)
        # 3) works in our non‑blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        # Encoder is created on the publisher thread; it is not shared.
        self._pack = msgspec.msgpack.Encoder()

        assert self._pub is not None  # narrows type for mypy

        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
                if event is None:
                    break  # Sentinel received, exit thread
            except queue.Empty:
                continue

            try:
                seq = next(self._seq_gen)

                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))

                # Keep the payload for replay requests.
                self._buffer.append((seq, payload))
                self._event_queue.task_done()

            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (-1, b""")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||||
|
|
||||||
|
|
||||||
|
class KVEventsConfig(BaseModel):
    """Configuration for KV event publishing."""

    publisher: str = "null"
    """The publisher to use for publishing kv events. Can be "null", "zmq".
    """

    endpoint: str = "tcp://*:5557"
    """The zmq endpoint to use for publishing kv events.
    """

    replay_endpoint: Optional[str] = None
    """The zmq endpoint to use for replaying kv events.
    """

    buffer_steps: int = 10_000
    """The number of steps to cache for replay endpoint. Will only save
    events from the last N steps for the replay endpoint.
    """

    hwm: int = 100_000
    """The zmq high water mark for the event publisher. After queueing N events,
    events will start dropping if the consumer is not keeping up.
    """

    max_queue_size: int = 100_000
    """The maximum number of events to queue while waiting for publishing.
    """

    topic: str = ""
    """The topic to use for the event publisher. Consumers can subscribe to
    this topic to receive events.
    """

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
        """Parse the CLI value (a JSON string) for the event publisher config.

        Uses ``cls`` rather than the class name so subclasses parse into
        their own type.
        """
        return cls.model_validate_json(cli_value)
|
||||||
|
|
||||||
|
|
||||||
|
class EventPublisherFactory:
    """Registry that maps publisher names to their constructors."""

    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
        """Register a new publisher constructor under `name`.

        Raises:
            KeyError: if `name` is already registered.
        """
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[str]) -> EventPublisher:
        """Create publisher from a config mapping."""
        if not config:
            return NullEventPublisher()
        # Keep the parsed model in its own variable rather than rebinding the
        # `config` parameter (str) to a different type.
        parsed_config = KVEventsConfig.from_cli(config)
        config_dict = parsed_config.model_dump()

        kind = config_dict.pop("publisher", "null")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(**config_dict)
|
||||||
@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.decode import (
|
|||||||
DecodeTransferQueue,
|
DecodeTransferQueue,
|
||||||
SchedulerDisaggregationDecodeMixin,
|
SchedulerDisaggregationDecodeMixin,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
||||||
from sglang.srt.disaggregation.prefill import (
|
from sglang.srt.disaggregation.prefill import (
|
||||||
PrefillBootstrapQueue,
|
PrefillBootstrapQueue,
|
||||||
SchedulerDisaggregationPrefillMixin,
|
SchedulerDisaggregationPrefillMixin,
|
||||||
@@ -197,6 +198,7 @@ class Scheduler(
|
|||||||
self.enable_overlap = not server_args.disable_overlap_schedule
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
||||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||||
self.enable_metrics = server_args.enable_metrics
|
self.enable_metrics = server_args.enable_metrics
|
||||||
|
self.enable_kv_cache_events = server_args.kv_events_config is not None
|
||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
@@ -204,7 +206,6 @@ class Scheduler(
|
|||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
||||||
self.page_size = server_args.page_size
|
self.page_size = server_args.page_size
|
||||||
|
|
||||||
# Distributed rank info
|
# Distributed rank info
|
||||||
self.dp_size = server_args.dp_size
|
self.dp_size = server_args.dp_size
|
||||||
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
|
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
|
||||||
@@ -422,6 +423,7 @@ class Scheduler(
|
|||||||
|
|
||||||
# Init metrics stats
|
# Init metrics stats
|
||||||
self.init_metrics()
|
self.init_metrics()
|
||||||
|
self.init_kv_events(server_args.kv_events_config)
|
||||||
|
|
||||||
# Init request dispatcher
|
# Init request dispatcher
|
||||||
self._request_dispatcher = TypeBasedDispatcher(
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
@@ -515,6 +517,7 @@ class Scheduler(
|
|||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
page_size=self.page_size,
|
page_size=self.page_size,
|
||||||
disable=server_args.disable_radix_cache,
|
disable=server_args.disable_radix_cache,
|
||||||
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decode_mem_cache_buf_multiplier = (
|
self.decode_mem_cache_buf_multiplier = (
|
||||||
@@ -547,6 +550,10 @@ class Scheduler(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def init_kv_events(self, kv_events_config: Optional[str]):
    """Create the KV event publisher when KV cache events are enabled."""
    if not self.enable_kv_cache_events:
        return
    self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
|
||||||
|
|
||||||
def init_disaggregation(self):
|
def init_disaggregation(self):
|
||||||
self.transfer_backend = TransferBackend(
|
self.transfer_backend = TransferBackend(
|
||||||
self.server_args.disaggregation_transfer_backend
|
self.server_args.disaggregation_transfer_backend
|
||||||
@@ -1154,6 +1161,7 @@ class Scheduler(
|
|||||||
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
|
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
|
||||||
|
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
|
self._publish_kv_events()
|
||||||
|
|
||||||
def log_decode_stats(
|
def log_decode_stats(
|
||||||
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
||||||
@@ -1213,6 +1221,7 @@ class Scheduler(
|
|||||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||||
self.stats.spec_accept_length = spec_accept_length
|
self.stats.spec_accept_length = spec_accept_length
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
|
self._publish_kv_events()
|
||||||
|
|
||||||
def check_memory(self):
|
def check_memory(self):
|
||||||
available_size = (
|
available_size = (
|
||||||
@@ -1260,6 +1269,7 @@ class Scheduler(
|
|||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
|
self._publish_kv_events()
|
||||||
|
|
||||||
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||||
# Merge the prefill batch into the running batch
|
# Merge the prefill batch into the running batch
|
||||||
@@ -2194,6 +2204,13 @@ class Scheduler(
|
|||||||
prefix += f" PP{self.pp_rank}"
|
prefix += f" PP{self.pp_rank}"
|
||||||
return prefix
|
return prefix
|
||||||
|
|
||||||
|
def _publish_kv_events(self):
|
||||||
|
if self.enable_kv_cache_events:
|
||||||
|
events = self.tree_cache.take_events()
|
||||||
|
if events:
|
||||||
|
batch = KVEventBatch(ts=time.time(), events=events)
|
||||||
|
self.kv_event_publisher.publish(batch)
|
||||||
|
|
||||||
|
|
||||||
def is_health_check_generate_req(recv_req):
|
def is_health_check_generate_req(recv_req):
|
||||||
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
||||||
|
|||||||
@@ -48,3 +48,6 @@ class BasePrefixCache(ABC):
|
|||||||
|
|
||||||
def pretty_print(self):
|
def pretty_print(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def take_events(self):
    """Return pending KV cache events (base implementation has none)."""
    return []
|
||||||
|
|||||||
@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.kv_events import (
|
||||||
|
AllBlocksCleared,
|
||||||
|
BlockRemoved,
|
||||||
|
BlockStored,
|
||||||
|
KVCacheEvent,
|
||||||
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
|
|||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
disable: bool = False,
|
disable: bool = False,
|
||||||
|
enable_kv_cache_events: bool = False,
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self.disable = disable
|
self.disable = disable
|
||||||
|
self.enable_kv_cache_events = enable_kv_cache_events
|
||||||
|
self.kv_event_queue = []
|
||||||
|
|
||||||
if self.token_to_kv_pool_allocator:
|
if self.token_to_kv_pool_allocator:
|
||||||
self.device = self.token_to_kv_pool_allocator.device
|
self.device = self.token_to_kv_pool_allocator.device
|
||||||
@@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
self.root_node.lock_ref = 1
|
self.root_node.lock_ref = 1
|
||||||
self.evictable_size_ = 0
|
self.evictable_size_ = 0
|
||||||
self.protected_size_ = 0
|
self.protected_size_ = 0
|
||||||
|
self._record_all_cleared_event()
|
||||||
|
|
||||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
||||||
"""Find the matching prefix from the radix tree.
|
"""Find the matching prefix from the radix tree.
|
||||||
@@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache):
|
|||||||
if len(x.parent.children) == 0:
|
if len(x.parent.children) == 0:
|
||||||
heapq.heappush(leaves, x.parent)
|
heapq.heappush(leaves, x.parent)
|
||||||
|
|
||||||
|
self._record_remove_event(x)
|
||||||
|
|
||||||
def inc_lock_ref(self, node: TreeNode):
|
def inc_lock_ref(self, node: TreeNode):
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return 0
|
return 0
|
||||||
@@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||||
# new_node -> child
|
# new_node -> child
|
||||||
|
self._record_remove_event(child)
|
||||||
new_node = TreeNode()
|
new_node = TreeNode()
|
||||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||||
new_node.parent = child.parent
|
new_node.parent = child.parent
|
||||||
@@ -358,6 +371,10 @@ class RadixCache(BasePrefixCache):
|
|||||||
child.key = child.key[split_len:]
|
child.key = child.key[split_len:]
|
||||||
child.value = child.value[split_len:]
|
child.value = child.value[split_len:]
|
||||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||||
|
|
||||||
|
self._record_store_event(new_node)
|
||||||
|
self._record_store_event(child)
|
||||||
|
|
||||||
return new_node
|
return new_node
|
||||||
|
|
||||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||||
@@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
new_node.value = value
|
new_node.value = value
|
||||||
node.children[child_key] = new_node
|
node.children[child_key] = new_node
|
||||||
self.evictable_size_ += len(value)
|
self.evictable_size_ += len(value)
|
||||||
|
self._record_store_event(new_node)
|
||||||
return total_prefix_length
|
return total_prefix_length
|
||||||
|
|
||||||
def _print_helper(self, node: TreeNode, indent: int):
|
def _print_helper(self, node: TreeNode, indent: int):
|
||||||
@@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
return ret_list
|
return ret_list
|
||||||
|
|
||||||
|
def _record_store_event(self, node: TreeNode):
    """Queue a BlockStored event for `node` when event recording is enabled.

    One radix-tree node maps to one logical block, identified by the hash of
    its token key; the parent link is recorded the same way.
    """
    if not self.enable_kv_cache_events:
        return
    self.kv_event_queue.append(
        BlockStored(
            block_hashes=[hash(tuple(node.key))],
            parent_block_hash=hash(tuple(node.parent.key)),
            token_ids=node.key,
            block_size=len(node.key),
            lora_id=None,
        )
    )
|
||||||
|
|
||||||
|
def _record_remove_event(self, node: TreeNode):
    """Queue a BlockRemoved event for `node` when event recording is enabled."""
    if not self.enable_kv_cache_events:
        return
    self.kv_event_queue.append(BlockRemoved(block_hashes=[hash(tuple(node.key))]))
|
||||||
|
|
||||||
|
def _record_all_cleared_event(self):
|
||||||
|
if self.enable_kv_cache_events:
|
||||||
|
self.kv_event_queue.append(AllBlocksCleared())
|
||||||
|
|
||||||
|
def take_events(self):
    """Atomically takes all events and clears the queue.

    Returns:
        A list of KV cache events.
    """
    if not self.enable_kv_cache_events:
        return []
    # Swap in a fresh list so the caller owns the drained events.
    events, self.kv_event_queue = self.kv_event_queue, []
    return events
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tree = RadixCache(None, None, page_size=1, disable=False)
|
tree = RadixCache(None, None, page_size=1, disable=False)
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class ServerArgs:
|
|||||||
collect_tokens_histogram: bool = False
|
collect_tokens_histogram: bool = False
|
||||||
decode_log_interval: int = 40
|
decode_log_interval: int = 40
|
||||||
enable_request_time_stats_logging: bool = False
|
enable_request_time_stats_logging: bool = False
|
||||||
|
kv_events_config: Optional[str] = None
|
||||||
|
|
||||||
# API related
|
# API related
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
@@ -814,6 +815,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.collect_tokens_histogram,
|
default=ServerArgs.collect_tokens_histogram,
|
||||||
help="Collect prompt/generation tokens histogram.",
|
help="Collect prompt/generation tokens histogram.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-events-config",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-log-interval",
|
"--decode-log-interval",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
247
test/srt/test_kv_events.py
Normal file
247
test/srt/test_kv_events.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
import requests
|
||||||
|
import zmq
|
||||||
|
from msgspec.msgpack import Decoder
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.kv_events import (
|
||||||
|
AllBlocksCleared,
|
||||||
|
BlockRemoved,
|
||||||
|
BlockStored,
|
||||||
|
EventBatch,
|
||||||
|
KVCacheEvent,
|
||||||
|
KVEventBatch,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestKvEvents(CustomTestCase):
    """End-to-end test: launch a server with ZMQ KV event publishing enabled
    and verify a subscriber receives the expected BlockStored/BlockRemoved
    events.
    """

    def test_kv_events_enabled(self):
        """Test that kv events are sent and received by subscriber data when enabled"""

        # Launch kv events subscriber
        decoder = Decoder(type=KVEventBatch)
        context = zmq.Context()
        sub = context.socket(zmq.SUB)
        # NOTE(review): assumes the server publishes on the default endpoint
        # port 5557 — confirm against ZmqEventPublisher's default.
        sub.connect("tcp://localhost:5557")
        topic = "kv-events"
        sub.setsockopt_string(zmq.SUBSCRIBE, topic)

        # Launch sglang server
        process = popen_launch_server(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                "--max-total-tokens",
                32,
                "--cuda-graph-max-bs",
                2,
            ],
        )

        try:
            # Make some requests to generate some metrics
            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
            self.assertEqual(response.status_code, 200)

            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of Spain is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )

            # Expected events. These may be dependent on model used (meta-llama/Llama-3.2-1B-Instruct)
            # The hash values below are Python hash() of the token-id tuples
            # produced by the radix cache, so they are tied to this exact model
            # and tokenizer.
            expected_events = [
                # <begin> The capital city of France is
                BlockStored(
                    block_hashes=[-6650323075460941099],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864, 3363, 315, 9822, 374],
                    block_size=7,
                    lora_id=None,
                ),
                # Paris. The Eiffel Tower
                BlockStored(
                    block_hashes=[-7584018293207282755],
                    parent_block_hash=-6650323075460941099,
                    token_ids=[12366, 13, 578, 469, 3168, 301, 22703],
                    block_size=7,
                    lora_id=None,
                ),
                BlockStored(
                    block_hashes=[-8753497827991233192],
                    parent_block_hash=5740354900026072187,
                    token_ids=[0],
                    block_size=1,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-6650323075460941099]),
                # <begin> The capital
                BlockStored(
                    block_hashes=[-2697055055087824455],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864],
                    block_size=3,
                    lora_id=None,
                ),
                # city of France is
                BlockStored(
                    block_hashes=[-7505627135785778022],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[3363, 315, 9822, 374],
                    block_size=4,
                    lora_id=None,
                ),
                # of France is
                BlockStored(
                    block_hashes=[-3861108700662737012],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315, 9822, 374],
                    block_size=3,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-7584018293207282755]),
                BlockRemoved(block_hashes=[-8753497827991233192]),
                BlockRemoved(block_hashes=[-7505627135785778022]),
                # Paris. The Eiffel Tower is located in Paris. The Eiffel Tower is a famous landmark in Paris
                BlockStored(
                    block_hashes=[-3064341286825792715],
                    parent_block_hash=-3861108700662737012,
                    token_ids=[
                        12366, 13, 578, 469, 3168, 301, 22703, 374, 7559, 304,
                        12366, 13, 578, 469, 3168, 301, 22703, 374, 264, 11495,
                        38350, 304, 12366,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3861108700662737012]),
                # of
                BlockStored(
                    block_hashes=[6115672085296369592],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315],
                    block_size=1,
                    lora_id=None,
                ),
                # France is
                BlockStored(
                    block_hashes=[4208810872343132234],
                    parent_block_hash=6115672085296369592,
                    token_ids=[9822, 374],
                    block_size=2,
                    lora_id=None,
                ),
                # Spain is
                BlockStored(
                    block_hashes=[1675819893649989955],
                    parent_block_hash=6115672085296369592,
                    token_ids=[18157, 374],
                    block_size=2,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3064341286825792715]),
                # Madrid. The capital of France is Paris. The capital of Italy is Rome. The capital of Spain is Madrid.
                BlockStored(
                    block_hashes=[-8505834929190027295],
                    parent_block_hash=1675819893649989955,
                    token_ids=[
                        25048, 13, 578, 6864, 315, 9822, 374, 12366, 13, 578,
                        6864, 315, 15704, 374, 22463, 13, 578, 6864, 315, 18157,
                        374, 25048, 13,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
            ]

            # Get events
            # Drain the subscriber until we have at least as many events as
            # expected, or the wait budget runs out.
            events = []
            start = time.time()
            max_wait_s = 5
            while (
                len(events) < len(expected_events)
                and (time.time() - start) < max_wait_s
            ):
                _, seq_bytes, payload = sub.recv_multipart()
                event_batch = decoder.decode(payload)
                for event in event_batch.events:
                    print(f" - {event}")
                    events.append(event)

            # Order-insensitive containment check; extra events are allowed.
            for expected in expected_events:
                self.assertIn(expected, events)

        finally:
            kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user