diff --git a/python/sglang/srt/disaggregation/kv_events.py b/python/sglang/srt/disaggregation/kv_events.py index 092c6b063..f0a3a4357 100644 --- a/python/sglang/srt/disaggregation/kv_events.py +++ b/python/sglang/srt/disaggregation/kv_events.py @@ -43,6 +43,7 @@ class EventBatch( ): ts: float events: list[Any] + attn_dp_rank: Optional[int] = None class KVCacheEvent( @@ -76,7 +77,21 @@ class KVEventBatch(EventBatch): class EventPublisher(ABC): - """Lightweight publisher for EventBatch batches.""" + """ + Lightweight publisher for EventBatch batches with + support for DP attention. + + In DP attention - each rank has its own Scheduler and + KV cache instance in order to avoid duplicate events + and ensure proper event attribution. In our implementation + + - Each DP rank has its own EventPublisher + - Publishers annotate events with the dp rank + - This allows consumers to distinguish events from different DP ranks + """ + + def __init__(self, attn_dp_rank: int = 0): + self._attn_dp_rank = attn_dp_rank @abstractmethod def publish(self, events: EventBatch) -> None: @@ -130,6 +145,7 @@ class ZmqEventPublisher(EventPublisher): def __init__( self, + attn_dp_rank: int, endpoint: str = "tcp://*:5557", replay_endpoint: Optional[str] = None, buffer_steps: int = 10_000, @@ -138,6 +154,7 @@ class ZmqEventPublisher(EventPublisher): topic: str = "", ) -> None: # Storage + super().__init__(attn_dp_rank) self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) @@ -145,8 +162,11 @@ class ZmqEventPublisher(EventPublisher): self._ctx = zmq.Context.instance() self._pub: Optional[zmq.Socket] = None self._replay: Optional[zmq.Socket] = None - self._endpoint = endpoint - self._replay_endpoint = replay_endpoint + self._dp_rank = attn_dp_rank + self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank) + self._replay_endpoint = self.offset_endpoint_port( + replay_endpoint, self._dp_rank + ) self._hwm = hwm 
self._socket_setup() @@ -168,6 +188,8 @@ class ZmqEventPublisher(EventPublisher): def publish(self, events: EventBatch) -> None: if not self._running: raise RuntimeError("Publisher is closed") + if events.attn_dp_rank is None: + events.attn_dp_rank = self._dp_rank self._event_queue.put(events) def shutdown(self) -> None: @@ -288,6 +310,39 @@ class ZmqEventPublisher(EventPublisher): # receiving payload is (-1, b""") self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) + @staticmethod + def offset_endpoint_port( + endpoint: Optional[str], data_parallel_rank: int + ) -> Optional[str]: + """Helper function to offset the port in an endpoint by + the data parallel rank. + + Args: + endpoint: The endpoint string + (e.g., "tcp://*:5557" or "inproc://cache") + data_parallel_rank: The data parallel rank to offset by + + Returns: + The endpoint with the port offset by data_parallel_rank + or suffix appended + """ + # Do nothing if input is None or data_parallel_rank is 0 + if not endpoint or data_parallel_rank == 0: + return endpoint + + if "inproc" in endpoint: + return f"{endpoint}_dp{data_parallel_rank}" + if "tcp" in endpoint: + if endpoint and ":" in endpoint: + # Get everything after the last colon (the port) + last_colon_idx = endpoint.rfind(":") + base_addr = endpoint[:last_colon_idx] + base_port = int(endpoint[last_colon_idx + 1 :]) + new_port = base_port + data_parallel_rank + return f"{base_addr}:{new_port}" + return endpoint + raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'") + class KVEventsConfig(BaseModel): """Configuration for KV event publishing.""" @@ -342,7 +397,7 @@ class EventPublisherFactory: cls._registry[name] = ctor @classmethod - def create(cls, config: Optional[str]) -> EventPublisher: + def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher: """Create publisher from a config mapping.""" if not config: return NullEventPublisher() @@ -354,4 +409,4 @@ class EventPublisherFactory: constructor = 
cls._registry[kind] except KeyError as exc: raise ValueError(f"Unknown event publisher '{kind}'") from exc - return constructor(**config_dict) + return constructor(attn_dp_rank=attn_dp_rank, **config_dict) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 382e200bd..5c2141d77 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -571,7 +571,9 @@ class Scheduler( def init_kv_events(self, kv_events_config: Optional[str]): if self.enable_kv_cache_events: - self.kv_event_publisher = EventPublisherFactory.create(kv_events_config) + self.kv_event_publisher = EventPublisherFactory.create( + kv_events_config, self.attn_dp_rank + ) def init_disaggregation(self): self.transfer_backend = TransferBackend( @@ -1988,7 +1990,7 @@ class Scheduler( self.cum_spec_accept_length = self.cum_spec_accept_count = 0 for k, v in server_args_dict.items(): global_server_args_dict[k] = v - logger.info(f"Global server args updated! " f"{global_server_args_dict=}") + logger.info(f"Global server args updated! 
def test_kv_events_attn_dp(self):
    """Test that kv events are properly tagged with DP rank in attention DP mode.

    Launches a server with ``--enable-dp-attention --dp-size 2`` and
    subscribes to one ZMQ endpoint per DP rank (base port 5557 plus the
    rank). Every batch received on a rank's endpoint must carry that
    rank in ``attn_dp_rank``, and each rank must produce a minimum
    number of KV cache events.
    """
    topic = "kv-events"
    base_port = 5557
    max_wait_s = 10
    min_events_per_rank = 3  # Expect at least a few events from each rank

    decoder = Decoder(type=KVEventBatch)
    context = zmq.Context()
    subscribers = {}  # dp rank -> SUB socket
    process = None

    # Everything that allocates a resource lives inside the try so that a
    # failed connect or server launch still runs the cleanup below.
    try:
        # One subscriber per DP rank; rank N listens on base_port + N.
        for rank in (0, 1):
            sub = context.socket(zmq.SUB)
            sub.connect(f"tcp://localhost:{base_port + rank}")
            sub.setsockopt_string(zmq.SUBSCRIBE, topic)
            subscribers[rank] = sub

        # Launch sglang server with DP attention enabled
        process = popen_launch_server(
            "silence09/DeepSeek-R1-Small-2layers",
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                "--max-total-tokens",
                64,
                "--cuda-graph-max-bs",
                4,
                "--enable-dp-attention",
                "--dp-size",
                2,
                "--tp-size",
                2,
            ],
        )

        # Make requests to generate events
        response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
        self.assertEqual(response.status_code, 200)

        # Send multiple requests to trigger events from both DP ranks
        for i in range(4):
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": f"Request {i}: The capital of country {i} is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 16,
                    },
                },
            )
            # Fail here rather than later with a misleading
            # "not enough events" error if a request was rejected.
            self.assertEqual(response.status_code, 200)

        # Collect events from both DP ranks until each has enough or the
        # time budget is exhausted.
        events_by_rank = {rank: [] for rank in subscribers}
        deadline = time.time() + max_wait_s
        while time.time() < deadline and any(
            len(collected) < min_events_per_rank
            for collected in events_by_rank.values()
        ):
            for rank, sub in subscribers.items():
                if not sub.poll(timeout=100):  # 100ms timeout
                    continue
                _, _seq_bytes, payload = sub.recv_multipart()
                event_batch = decoder.decode(payload)
                print(
                    f"DP Rank {rank} - EventBatch: ts={event_batch.ts}, "
                    f"attn_dp_rank={event_batch.attn_dp_rank}"
                )
                self.assertEqual(
                    event_batch.attn_dp_rank,
                    rank,
                    f"DP rank {rank} events should have attn_dp_rank={rank}",
                )
                for event in event_batch.events:
                    print(f"  DP{rank} - {event}")
                    events_by_rank[rank].append(event)

        # Verify we got events from both DP ranks
        for rank, collected in events_by_rank.items():
            print(f"Collected {len(collected)} events from DP rank {rank}")
            self.assertGreaterEqual(
                len(collected),
                min_events_per_rank,
                f"Expected at least {min_events_per_rank} events "
                f"from DP rank {rank}",
            )

        # Verify event types are as expected
        for collected in events_by_rank.values():
            for event in collected:
                self.assertIsInstance(
                    event,
                    (BlockStored, BlockRemoved, AllBlocksCleared),
                    f"Event should be a KV cache event, got {type(event)}",
                )

    finally:
        # Reap the server first so it never outlives the test, then
        # release the ZMQ sockets and context.
        if process is not None:
            kill_process_tree(process.pid)
        for sub in subscribers.values():
            sub.close()
        context.term()