2025-05-19 14:19:54 -07:00
|
|
|
|
"""
|
|
|
|
|
|
Copyright 2025 SGLang Team
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
|
limitations under the License.
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
KV caching events
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import atexit
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import queue
|
|
|
|
|
|
import threading
|
|
|
|
|
|
import time
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
from collections import deque
|
|
|
|
|
|
from itertools import count
|
|
|
|
|
|
from queue import Queue
|
|
|
|
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
|
|
|
|
|
|
|
|
import msgspec
|
|
|
|
|
|
import zmq
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger for the publishers defined in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EventBatch(
    msgspec.Struct,
    gc=False,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,  # type: ignore[call-arg]
):
    """A timestamped group of events, encoded as a positional msgpack array.

    Because the struct is ``array_like``, fields are encoded by position: the
    declaration order below is part of the wire format and must not change.
    With ``omit_defaults``, a trailing ``attn_dp_rank`` left at ``None`` is
    left out of the encoded array.
    """

    # Batch timestamp (units not fixed here -- presumably epoch seconds, TODO confirm).
    ts: float
    # The events carried by this batch; subclasses narrow the element type.
    events: list[Any]
    # Attention data-parallel rank that produced the batch, if annotated.
    attn_dp_rank: Optional[int] = None
|
2025-05-19 14:19:54 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KVCacheEvent(
    msgspec.Struct,
    tag=True,
    gc=False,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,  # type: ignore[call-arg]
):
    """Common base class of every KV cache event.

    ``tag=True`` makes msgspec tag each encoded event with its concrete class
    name, which is what allows a union of subclasses to be decoded.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BlockStored(KVCacheEvent):
    """KV cache event reporting stored blocks (semantics defined by the emitter).

    Encoded positionally (inherited ``array_like``), so the field order below
    is part of the wire format.
    """

    # Hashes of the blocks this event refers to.
    block_hashes: list[int]
    # Hash of the parent block, or None (presumably when there is no parent).
    parent_block_hash: Optional[int]
    # Token ids covered by the blocks.
    token_ids: list[int]
    # Block size (presumably tokens per block -- confirm with the emitter).
    block_size: int
    # LoRA adapter id, or None.
    lora_id: Optional[int]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BlockRemoved(KVCacheEvent):
    """KV cache event listing the hashes of blocks being removed."""

    # Hashes of the affected blocks.
    block_hashes: list[int]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AllBlocksCleared(KVCacheEvent):
    """Payload-free KV cache event signalling that all blocks were cleared."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KVEventBatch(EventBatch):
    """An EventBatch whose events are limited to the KV cache event types.

    Narrowing ``events`` to a union of tagged structs lets msgspec decode each
    element into its concrete subclass.
    """

    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EventPublisher(ABC):
    """Abstract sink for EventBatch objects that is aware of DP attention.

    Under DP attention every rank runs its own Scheduler and KV cache
    instance. To avoid duplicate events and attribute each event to the rank
    that produced it:

    - every DP rank owns a separate EventPublisher;
    - each publisher remembers its rank so batches can be annotated with it;
    - consumers can therefore tell apart events coming from different ranks.
    """

    def __init__(self, attn_dp_rank: int = 0):
        # Attention data-parallel rank served by this publisher.
        self._attn_dp_rank = attn_dp_rank

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit ``events`` in order.

        Implementations should provide at-least-once delivery and monotonic
        ordering (for instance by stamping sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Stop the publisher and release whatever it holds."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NullEventPublisher(EventPublisher):
    """Publisher that drops everything; the default when KV events are disabled."""

    def publish(self, events) -> None:
        # Deliberately a no-op: batches are discarded.
        pass

    def shutdown(self) -> None:
        # Nothing was acquired, so there is nothing to release.
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue. Each batch is
    msgpack-encoded and sent on the PUB socket as the multipart message
    ``(topic, seq, payload)``, where ``seq`` is an 8-byte big-endian sequence
    number. The ``(seq, payload)`` pair is also kept in a bounded replay
    buffer.

    Parameters
    ----------
    attn_dp_rank:
        Attention data-parallel rank. Batches published without a rank are
        tagged with it, and both endpoints are offset by it (see
        ``offset_endpoint_port``).
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of batches waiting to be published; ``publish`` blocks
        while the queue is full.
    topic:
        Topic to publish events to.
    """

    SHUTDOWN_TIMEOUT: float = 1.0
    # Sequence number -1, sent to mark the end of a replay stream.
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        attn_dp_rank: int = 0,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        super().__init__(attn_dp_rank)

        # Storage
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._dp_rank = attn_dp_rank
        # Each DP rank gets its own port (tcp) or address suffix (ipc/inproc).
        self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
        self._replay_endpoint = self.offset_endpoint_port(
            replay_endpoint, self._dp_rank
        )
        self._hwm = hwm
        try:
            self._socket_setup()
        except Exception:
            # Don't leave half-configured sockets open on the shared context.
            self._close_sockets()
            raise

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()

        atexit.register(self.shutdown)

    def publish(self, events: EventBatch) -> None:
        """Queue ``events`` for the background thread to publish.

        A batch without ``attn_dp_rank`` is tagged with this publisher's rank.
        Blocks while the queue is full.

        Raises:
            RuntimeError: if ``shutdown`` has already been called.
        """
        if not self._running:
            raise RuntimeError("Publisher is closed")
        if events.attn_dp_rank is None:
            events.attn_dp_rank = self._dp_rank
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources.

        Waits up to ``SHUTDOWN_TIMEOUT`` seconds for queued batches to drain,
        then up to ``SHUTDOWN_TIMEOUT`` more for the thread to exit. Safe to
        call more than once: ``__init__`` registers it with ``atexit``, so an
        explicit call is normally followed by a second one at exit.
        """
        if not self._running:
            # Already shut down. Repeating the work would enqueue a sentinel
            # no thread consumes, stall for SHUTDOWN_TIMEOUT and warn.
            return
        self._running = False

        try:
            # Wake the worker promptly. If the queue is full the sentinel is
            # not needed: the loop exits once _running is False and the queue
            # has drained.
            self._event_queue.put_nowait(None)
        except queue.Full:
            pass

        start = time.time()
        pending_items = True
        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        if self._thread.is_alive():
            # ZMQ sockets are not thread-safe; closing them while the worker
            # may still use them risks a crash. The thread is a daemon, so it
            # and its sockets go away with the process.
            logger.warning(
                "ZMQ publisher thread still running after %s seconds; "
                "leaving its sockets open",
                self.SHUTDOWN_TIMEOUT,
            )
            return

        # Clean up ZMQ resources. The context comes from
        # zmq.Context.instance() and may be shared, so it is not terminated.
        self._close_sockets()

    def _close_sockets(self) -> None:
        """Close the PUB and replay sockets, if created, discarding unsent data."""
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
        finally:
            if self._replay is not None:
                self._replay.close(linger=0)

    def _socket_setup(self) -> None:
        """Initialize sockets

        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith("ipc://")
                or self._endpoint.startswith("inproc://")
            ):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request -> many replies (streamed events)
        # 3) works in our non-blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread: serve replay requests and publish queued batches.

        Exits when the ``None`` sentinel is dequeued, or once ``_running`` is
        False and the queue is empty.
        """
        self._pack = msgspec.msgpack.Encoder()

        assert self._pub is not None  # narrows type for mypy

        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            if event is None:
                self._event_queue.task_done()
                break  # Sentinel received, exit thread

            try:
                self._send_batch(event)
            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)
            finally:
                # Balance every get(), including batches that failed to send.
                self._event_queue.task_done()

    def _send_batch(self, event: EventBatch) -> None:
        """Encode one batch, give it the next sequence number, buffer and send it."""
        assert self._pub is not None  # narrows type for mypy
        # Encode before taking a sequence number, so an unencodable batch does
        # not leave a gap in the numbering.
        payload = self._pack.encode(event)
        seq = next(self._seq_gen)
        # Buffer before sending: if the send fails, a subscriber that notices
        # the gap can still recover the batch through replay.
        self._buffer.append((seq, payload))
        self._pub.send_multipart((self._topic_bytes, seq.to_bytes(8, "big"), payload))

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches.

        Expects ``[identity, empty_delim, start_seq]``. Replies with every
        buffered batch whose sequence number is >= start_seq, followed by an
        END_SEQ marker.
        """
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # ROUTER uses the identity frame to route the reply; a REQ
                # client then sees (seq_bytes, payload).
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (END_SEQ, b"")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))

    @staticmethod
    def offset_endpoint_port(
        endpoint: Optional[str], data_parallel_rank: int
    ) -> Optional[str]:
        """Helper function to offset the port in an endpoint by
        the data parallel rank.

        Args:
            endpoint: The endpoint string
                (e.g., "tcp://*:5557", "ipc:///tmp/kv" or "inproc://cache")
            data_parallel_rank: The data parallel rank to offset by

        Returns:
            ``endpoint`` unchanged when it is empty/None or the rank is 0.
            Otherwise, for ``tcp://`` the port is increased by
            ``data_parallel_rank``; for ``inproc://`` and ``ipc://`` the suffix
            ``_dp<rank>`` is appended.

        Raises:
            ValueError: for any other transport, or a tcp endpoint whose port
                is not an integer.
        """
        # Do nothing if input is None or data_parallel_rank is 0
        if not endpoint or data_parallel_rank == 0:
            return endpoint

        # Match on the transport scheme, not a substring anywhere in the string
        # (a host such as "inproc-gw" must not be treated as inproc). ipc://
        # addresses are filesystem paths, so a suffix keeps them unique per
        # rank just like inproc names.
        if endpoint.startswith(("inproc://", "ipc://")):
            return f"{endpoint}_dp{data_parallel_rank}"

        if endpoint.startswith("tcp://"):
            # The port follows the last colon (also handles "[::1]:5557").
            last_colon_idx = endpoint.rfind(":")
            base_addr = endpoint[:last_colon_idx]
            base_port = int(endpoint[last_colon_idx + 1 :])
            return f"{base_addr}:{base_port + data_parallel_rank}"

        raise ValueError(
            "Invalid endpoint: must start with 'tcp://', 'ipc://' or 'inproc://'"
        )
|
|
|
|
|
|
|
2025-05-19 14:19:54 -07:00
|
|
|
|
|
|
|
|
|
|
class KVEventsConfig(BaseModel):
    """Configuration for KV event publishing.

    Typically parsed from a JSON string with ``from_cli``.
    """

    publisher: str = "null"
    """The publisher to use for publishing kv events: "null" (disabled), "zmq",
    or any name added with ``EventPublisherFactory.register_publisher``.
    """

    endpoint: str = "tcp://*:5557"
    """The zmq endpoint to use for publishing kv events.
    """

    replay_endpoint: Optional[str] = None
    """The zmq endpoint to use for replaying kv events.
    """

    buffer_steps: int = 10_000
    """The number of steps to cache for replay endpoint. Will only save
    events from the last N steps for the replay endpoint.
    """

    hwm: int = 100_000
    """The zmq high water mark for the event publisher. After queueing N events,
    events will start dropping if the consumer is not keeping up.
    """

    max_queue_size: int = 100_000
    """The maximum number of events to queue while waiting for publishing.
    """

    topic: str = ""
    """The topic to use for the event publisher. Consumers can subscribe to
    this topic to receive events.
    """

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
        """Parse the CLI value (a JSON object) for the event publisher config.

        Validates against ``cls`` so a subclass gets an instance of itself
        rather than of the base ``KVEventsConfig``.
        """
        return cls.model_validate_json(cli_value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EventPublisherFactory:
    """Builds EventPublisher instances by name from a registry of constructors."""

    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
        """Register ``ctor`` under ``name``.

        ``create`` calls ``ctor(attn_dp_rank=..., **fields)``, where ``fields``
        are the KVEventsConfig fields other than ``publisher``.

        Raises:
            KeyError: if ``name`` is already registered.
        """
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
        """Create a publisher from a JSON-encoded KVEventsConfig.

        Args:
            config: JSON string parsed by ``KVEventsConfig.from_cli``. An empty
                value or None yields a NullEventPublisher.
            attn_dp_rank: Attention data-parallel rank passed to the publisher.

        Raises:
            ValueError: if the configured publisher name is not registered.
        """
        if not config:
            return NullEventPublisher(attn_dp_rank)

        events_config = KVEventsConfig.from_cli(config)
        config_dict = events_config.model_dump()

        kind = config_dict.pop("publisher", "null")
        if kind == "null":
            # The no-op publisher accepts only attn_dp_rank; forwarding the
            # transport fields (endpoint, hwm, ...) would raise TypeError.
            # "null" cannot be re-registered, so this always means the no-op.
            return NullEventPublisher(attn_dp_rank)
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(attn_dp_rank=attn_dp_rank, **config_dict)
|