diff --git a/python/pyproject.toml b/python/pyproject.toml index af02b51bb..627b1949c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -25,6 +25,7 @@ runtime_common = [ "interegular", "llguidance>=0.7.11,<0.8.0", "modelscope", + "msgspec", "ninja", "orjson", "packaging", diff --git a/python/sglang/srt/disaggregation/kv_events.py b/python/sglang/srt/disaggregation/kv_events.py new file mode 100644 index 000000000..092c6b063 --- /dev/null +++ b/python/sglang/srt/disaggregation/kv_events.py @@ -0,0 +1,357 @@ +""" +Copyright 2025 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +""" +KV caching events +""" + +import atexit +import logging +import queue +import threading +import time +from abc import ABC, abstractmethod +from collections import deque +from itertools import count +from queue import Queue +from typing import Any, Callable, Optional, Union + +import msgspec +import zmq +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class EventBatch( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] +): + ts: float + events: list[Any] + + +class KVCacheEvent( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] + tag=True, +): + """Base class for all KV cache-related events""" + + +class BlockStored(KVCacheEvent): + block_hashes: list[int] + parent_block_hash: Optional[int] + token_ids: list[int] + block_size: int + lora_id: Optional[int] + + +class BlockRemoved(KVCacheEvent): + block_hashes: list[int] + + +class AllBlocksCleared(KVCacheEvent): + pass + + +class KVEventBatch(EventBatch): + events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + + +class EventPublisher(ABC): + """Lightweight publisher for EventBatch batches.""" + + @abstractmethod + def publish(self, events: EventBatch) -> None: + """Emit events in order. + + Implementations should guarantee at-least-once delivery and + monotonic ordering (e.g., via sequence numbers). + """ + + @abstractmethod + def shutdown(self) -> None: + """Shutdown the publisher.""" + + +class NullEventPublisher(EventPublisher): + """No-op implementation (default when disabled).""" + + def publish(self, events) -> None: + return + + def shutdown(self) -> None: + return + + +class ZmqEventPublisher(EventPublisher): + """Reliable PUB/ROUTER publisher with an in-memory replay buffer. + + Spawns a separate thread to handle publishing from a queue. 
+ + Parameters + ---------- + endpoint: + PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to + connect. + replay_endpoint: + Optional ROUTER address for replay requests. When given, subscribers can + request missed batches by sending the starting sequence number as an + 8-byte big-endian integer. + buffer_steps: + Number of past batches to keep for replay. + hwm: + ZeroMQ high-water-mark for PUB socket. + max_queue_size: + Maximum number of events to buffer in memory. + topic: + Topic to publish events to. + """ + + SHUTDOWN_TIMEOUT: float = 1.0 + END_SEQ = (-1).to_bytes(8, "big", signed=True) + + def __init__( + self, + endpoint: str = "tcp://*:5557", + replay_endpoint: Optional[str] = None, + buffer_steps: int = 10_000, + hwm: int = 100_000, + max_queue_size: int = 100_000, + topic: str = "", + ) -> None: + # Storage + self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) + self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) + + # ZMQ sockets + self._ctx = zmq.Context.instance() + self._pub: Optional[zmq.Socket] = None + self._replay: Optional[zmq.Socket] = None + self._endpoint = endpoint + self._replay_endpoint = replay_endpoint + self._hwm = hwm + self._socket_setup() + + # Payload + self._seq_gen = count() + self._topic_bytes = topic.encode("utf-8") + + # Thread + self._running = True + logger.info("Starting ZMQ publisher thread") + + self._thread = threading.Thread( + target=self._publisher_thread, daemon=True, name="zmq-publisher" + ) + self._thread.start() + + atexit.register(self.shutdown) + + def publish(self, events: EventBatch) -> None: + if not self._running: + raise RuntimeError("Publisher is closed") + self._event_queue.put(events) + + def shutdown(self) -> None: + """Stop the publisher thread and clean up resources.""" + self._running = False + self._event_queue.put_nowait(None) + + start = time.time() + pending_items = True + while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT): + 
pending_items = not self._event_queue.empty() + if pending_items: + time.sleep(0.1) + + if pending_items: + logger.warning( + "Warning: Queue still has %s items after %s seconds timeout", + self._event_queue.qsize(), + self.SHUTDOWN_TIMEOUT, + ) + + if self._thread.is_alive(): + self._thread.join(timeout=self.SHUTDOWN_TIMEOUT) + + # Clean up ZMQ resources + try: + if self._pub is not None: + self._pub.close(linger=0) + if self._replay is not None: + self._replay.close(linger=0) + finally: + pass # Do not terminate context; other sockets may use it + + def _socket_setup(self) -> None: + """Initialize sockets + https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety + """ + if self._pub is None: + self._pub = self._ctx.socket(zmq.PUB) + self._pub.set_hwm(self._hwm) + # Heuristic: bind if wildcard / * present, else connect. + # bind stable, connect volatile convention + if ( + "*" in self._endpoint + or "::" in self._endpoint + or self._endpoint.startswith("ipc://") + or self._endpoint.startswith("inproc://") + ): + self._pub.bind(self._endpoint) + else: + self._pub.connect(self._endpoint) + + # Set up replay socket: use ROUTER + # 1) handles multiple REQ clients (identities) + # 2) lets us send back one request → many replies (streamed events) + # 3) works in our non‑blocking poll loop alongside PUB + if self._replay_endpoint is not None: + self._replay = self._ctx.socket(zmq.ROUTER) + self._replay.bind(self._replay_endpoint) + + def _publisher_thread(self) -> None: + """Background thread that processes the event queue.""" + self._pack = msgspec.msgpack.Encoder() + + assert self._pub is not None # narrows type for mypy + + while self._running or self._event_queue.qsize() > 0: + # --- replay (non-critical) --------------------------------- + if self._replay is not None and self._replay.poll(0): + try: + self._service_replay() + except Exception as e: + logger.exception("Error in replay: %s", e) + + # --- main queue (critical) 
--------------------------------- + try: + event = self._event_queue.get(timeout=0.1) + if event is None: + break # Sentinel received, exit thread + except queue.Empty: + continue + + try: + seq = next(self._seq_gen) + + payload = self._pack.encode(event) + seq_bytes = seq.to_bytes(8, "big") + self._pub.send_multipart((self._topic_bytes, seq_bytes, payload)) + + self._buffer.append((seq, payload)) + self._event_queue.task_done() + + except Exception as e: + # Publishing failed; back-off a bit to avoid a tight error loop + logger.exception("Error in publisher thread: %s", e) + time.sleep(0.1) + + def _service_replay(self) -> None: + """If a replay request is waiting, send buffered batches.""" + assert self._replay is not None # narrows type for mypy + + frame = self._replay.recv_multipart() + if len(frame) != 3: + logger.warning("Invalid replay request: %s", frame) + return + client_id, _, start_seq_bytes = frame + start_seq = int.from_bytes(start_seq_bytes, "big") + + for seq, buf in self._buffer: + if seq >= start_seq: + # [identity, empty_delim, seq_bytes, payload] + # (identity, empty_delim) are stripped off by the router + # receiving payload is (seq_bytes, payload) + self._replay.send_multipart( + (client_id, b"", seq.to_bytes(8, "big"), buf) + ) + # Send end of sequence marker + # receiving payload is (-1, b""") + self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) + + +class KVEventsConfig(BaseModel): + """Configuration for KV event publishing.""" + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. 
+ """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. + """ + + @classmethod + def from_cli(cls, cli_value: str) -> "KVEventsConfig": + """Parse the CLI value for the event publisher config.""" + return KVEventsConfig.model_validate_json(cli_value) + + +class EventPublisherFactory: + _registry: dict[str, Callable[..., EventPublisher]] = { + "null": NullEventPublisher, + "zmq": ZmqEventPublisher, + } + + @classmethod + def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None: + if name in cls._registry: + raise KeyError(f"publisher '{name}' already registered") + cls._registry[name] = ctor + + @classmethod + def create(cls, config: Optional[str]) -> EventPublisher: + """Create publisher from a config mapping.""" + if not config: + return NullEventPublisher() + config = KVEventsConfig.from_cli(config) + config_dict = config.model_dump() + + kind = config_dict.pop("publisher", "null") + try: + constructor = cls._registry[kind] + except KeyError as exc: + raise ValueError(f"Unknown event publisher '{kind}'") from exc + return constructor(**config_dict) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 937b3552a..0506460b1 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -41,6 +41,7 @@ from sglang.srt.disaggregation.decode import ( DecodeTransferQueue, SchedulerDisaggregationDecodeMixin, ) +from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch from sglang.srt.disaggregation.prefill import ( PrefillBootstrapQueue, SchedulerDisaggregationPrefillMixin, @@ 
-197,6 +198,7 @@ class Scheduler( self.enable_overlap = not server_args.disable_overlap_schedule self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics + self.enable_kv_cache_events = server_args.kv_events_config is not None self.stream_interval = server_args.stream_interval self.spec_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm @@ -204,7 +206,6 @@ class Scheduler( self.gpu_id = gpu_id self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.page_size = server_args.page_size - # Distributed rank info self.dp_size = server_args.dp_size self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = ( @@ -422,6 +423,7 @@ class Scheduler( # Init metrics stats self.init_metrics() + self.init_kv_events(server_args.kv_events_config) # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( @@ -515,6 +517,7 @@ class Scheduler( token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, page_size=self.page_size, disable=server_args.disable_radix_cache, + enable_kv_cache_events=self.enable_kv_cache_events, ) self.decode_mem_cache_buf_multiplier = ( @@ -547,6 +550,10 @@ class Scheduler( }, ) + def init_kv_events(self, kv_events_config: Optional[str]): + if self.enable_kv_cache_events: + self.kv_event_publisher = EventPublisherFactory.create(kv_events_config) + def init_disaggregation(self): self.transfer_backend = TransferBackend( self.server_args.disaggregation_transfer_backend @@ -1154,6 +1161,7 @@ class Scheduler( self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq self.metrics_collector.log_stats(self.stats) + self._publish_kv_events() def log_decode_stats( self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None @@ -1213,6 +1221,7 @@ class Scheduler( self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.spec_accept_length = spec_accept_length self.metrics_collector.log_stats(self.stats) + 
self._publish_kv_events() def check_memory(self): available_size = ( @@ -1260,6 +1269,7 @@ class Scheduler( self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.metrics_collector.log_stats(self.stats) + self._publish_kv_events() def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch @@ -2194,6 +2204,13 @@ class Scheduler( prefix += f" PP{self.pp_rank}" return prefix + def _publish_kv_events(self): + if self.enable_kv_cache_events: + events = self.tree_cache.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + def is_health_check_generate_req(recv_req): return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK") diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index f370346e1..2035bbdbf 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -48,3 +48,6 @@ class BasePrefixCache(ABC): def pretty_print(self): raise NotImplementedError() + + def take_events(self): + return [] diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index b1fd645be..bdcb7640f 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple import torch +from sglang.srt.disaggregation.kv_events import ( + AllBlocksCleared, + BlockRemoved, + BlockStored, + KVCacheEvent, +) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator @@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache): token_to_kv_pool_allocator: TokenToKVPoolAllocator, page_size: int, 
disable: bool = False, + enable_kv_cache_events: bool = False, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.page_size = page_size self.disable = disable + self.enable_kv_cache_events = enable_kv_cache_events + self.kv_event_queue = [] if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device @@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache): self.root_node.lock_ref = 1 self.evictable_size_ = 0 self.protected_size_ = 0 + self._record_all_cleared_event() def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: """Find the matching prefix from the radix tree. @@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache): if len(x.parent.children) == 0: heapq.heappush(leaves, x.parent) + self._record_remove_event(x) + def inc_lock_ref(self, node: TreeNode): if self.disable: return 0 @@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache): def _split_node(self, key, child: TreeNode, split_len: int): # new_node -> child + self._record_remove_event(child) new_node = TreeNode() new_node.children = {self.get_child_key_fn(key[split_len:]): child} new_node.parent = child.parent @@ -358,6 +371,10 @@ class RadixCache(BasePrefixCache): child.key = child.key[split_len:] child.value = child.value[split_len:] new_node.parent.children[self.get_child_key_fn(key)] = new_node + + self._record_store_event(new_node) + self._record_store_event(child) + return new_node def _insert_helper(self, node: TreeNode, key: List, value): @@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache): new_node.value = value node.children[child_key] = new_node self.evictable_size_ += len(value) + self._record_store_event(new_node) return total_prefix_length def _print_helper(self, node: TreeNode, indent: int): @@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache): return ret_list + def _record_store_event(self, node: TreeNode): + if self.enable_kv_cache_events: + block_hash = hash(tuple(node.key)) 
+ parent_block_hash = hash(tuple(node.parent.key)) + self.kv_event_queue.append( + BlockStored( + block_hashes=[block_hash], + parent_block_hash=parent_block_hash, + token_ids=node.key, + block_size=len(node.key), + lora_id=None, + ) + ) + + def _record_remove_event(self, node: TreeNode): + if self.enable_kv_cache_events: + block_hash = hash(tuple(node.key)) + self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash])) + + def _record_all_cleared_event(self): + if self.enable_kv_cache_events: + self.kv_event_queue.append(AllBlocksCleared()) + + def take_events(self): + """Atomically takes all events and clears the queue. + + Returns: + A list of KV cache events. + """ + if not self.enable_kv_cache_events: + return [] + events = self.kv_event_queue + self.kv_event_queue = [] + return events + if __name__ == "__main__": tree = RadixCache(None, None, page_size=1, disable=False) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 62c4b990f..1e650fe71 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -103,6 +103,7 @@ class ServerArgs: collect_tokens_histogram: bool = False decode_log_interval: int = 40 enable_request_time_stats_logging: bool = False + kv_events_config: Optional[str] = None # API related api_key: Optional[str] = None @@ -814,6 +815,12 @@ class ServerArgs: default=ServerArgs.collect_tokens_histogram, help="Collect prompt/generation tokens histogram.", ) + parser.add_argument( + "--kv-events-config", + type=str, + default=None, + help="Config in json format for NVIDIA dynamo KV event publishing. 
import time
import unittest

import requests
import zmq
from msgspec.msgpack import Decoder

from sglang.srt.disaggregation.kv_events import (
    BlockRemoved,
    BlockStored,
    KVEventBatch,
)
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestKvEvents(CustomTestCase):
    def test_kv_events_enabled(self):
        """Test that kv events are sent and received by subscriber data when enabled"""

        # Launch kv events subscriber before the server so no events are missed
        # (ZMQ allows SUB connect before the PUB side binds).
        decoder = Decoder(type=KVEventBatch)
        context = zmq.Context()
        sub = context.socket(zmq.SUB)
        sub.connect("tcp://localhost:5557")
        topic = "kv-events"
        sub.setsockopt_string(zmq.SUBSCRIBE, topic)
        # FIX: bound each receive; without RCVTIMEO the recv_multipart() below
        # blocks forever when fewer events than expected arrive, so the
        # max_wait_s guard was never re-checked.
        sub.setsockopt(zmq.RCVTIMEO, 1000)

        # Launch sglang server with ZMQ KV-event publishing enabled.
        process = popen_launch_server(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                # FIX: argv entries must be strings for subprocess.
                "--max-total-tokens",
                "32",
                "--cuda-graph-max-bs",
                "2",
            ],
        )

        try:
            # Make some requests to generate some metrics
            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
            self.assertEqual(response.status_code, 200)

            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of Spain is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )

            # Expected events. These may be dependent on model used (meta-llama/Llama-3.2-1B-Instruct)
            expected_events = [
                # The capital city of France is
                BlockStored(
                    block_hashes=[-6650323075460941099],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864, 3363, 315, 9822, 374],
                    block_size=7,
                    lora_id=None,
                ),
                # Paris. The Eiffel Tower
                BlockStored(
                    block_hashes=[-7584018293207282755],
                    parent_block_hash=-6650323075460941099,
                    token_ids=[12366, 13, 578, 469, 3168, 301, 22703],
                    block_size=7,
                    lora_id=None,
                ),
                BlockStored(
                    block_hashes=[-8753497827991233192],
                    parent_block_hash=5740354900026072187,
                    token_ids=[0],
                    block_size=1,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-6650323075460941099]),
                # The capital
                BlockStored(
                    block_hashes=[-2697055055087824455],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864],
                    block_size=3,
                    lora_id=None,
                ),
                # city of France is
                BlockStored(
                    block_hashes=[-7505627135785778022],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[3363, 315, 9822, 374],
                    block_size=4,
                    lora_id=None,
                ),
                # of France is
                BlockStored(
                    block_hashes=[-3861108700662737012],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315, 9822, 374],
                    block_size=3,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-7584018293207282755]),
                BlockRemoved(block_hashes=[-8753497827991233192]),
                BlockRemoved(block_hashes=[-7505627135785778022]),
                # Paris. The Eiffel Tower is located in Paris. The Eiffel Tower is a famous landmark in Paris
                BlockStored(
                    block_hashes=[-3064341286825792715],
                    parent_block_hash=-3861108700662737012,
                    token_ids=[
                        12366,
                        13,
                        578,
                        469,
                        3168,
                        301,
                        22703,
                        374,
                        7559,
                        304,
                        12366,
                        13,
                        578,
                        469,
                        3168,
                        301,
                        22703,
                        374,
                        264,
                        11495,
                        38350,
                        304,
                        12366,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3861108700662737012]),
                # of
                BlockStored(
                    block_hashes=[6115672085296369592],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315],
                    block_size=1,
                    lora_id=None,
                ),
                # France is
                BlockStored(
                    block_hashes=[4208810872343132234],
                    parent_block_hash=6115672085296369592,
                    token_ids=[9822, 374],
                    block_size=2,
                    lora_id=None,
                ),
                # Spain is
                BlockStored(
                    block_hashes=[1675819893649989955],
                    parent_block_hash=6115672085296369592,
                    token_ids=[18157, 374],
                    block_size=2,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3064341286825792715]),
                # Madrid. The capital of France is Paris. The capital of Italy is Rome. The capital of Spain is Madrid.
                BlockStored(
                    block_hashes=[-8505834929190027295],
                    parent_block_hash=1675819893649989955,
                    token_ids=[
                        25048,
                        13,
                        578,
                        6864,
                        315,
                        9822,
                        374,
                        12366,
                        13,
                        578,
                        6864,
                        315,
                        15704,
                        374,
                        22463,
                        13,
                        578,
                        6864,
                        315,
                        18157,
                        374,
                        25048,
                        13,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
            ]

            # Collect events until we have enough or the overall deadline passes.
            events = []
            start = time.time()
            max_wait_s = 5
            while (
                len(events) < len(expected_events)
                and (time.time() - start) < max_wait_s
            ):
                try:
                    _, seq_bytes, payload = sub.recv_multipart()
                except zmq.Again:
                    # Receive timed out; re-check the overall deadline.
                    continue
                event_batch = decoder.decode(payload)
                for event in event_batch.events:
                    print(f" - {event}")
                    events.append(event)

            for expected in expected_events:
                self.assertIn(expected, events)

        finally:
            kill_process_tree(process.pid)
            # FIX: release ZMQ resources so repeated runs don't leak sockets.
            sub.close(linger=0)
            context.term()


if __name__ == "__main__":
    unittest.main()