feat: add dp-rank to KV events (#6852)
This commit is contained in:
@@ -43,6 +43,7 @@ class EventBatch(
|
||||
):
|
||||
ts: float
|
||||
events: list[Any]
|
||||
attn_dp_rank: Optional[int] = None
|
||||
|
||||
|
||||
class KVCacheEvent(
|
||||
@@ -76,7 +77,21 @@ class KVEventBatch(EventBatch):
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches."""
|
||||
"""
|
||||
Lightweight publisher for EventBatch batches with
|
||||
support for DP attention.
|
||||
|
||||
In DP attention - each rank has its own Scheduler and
|
||||
KV cache instance in order to avoid duplicate events
|
||||
and ensure proper event attribution. In our implementation
|
||||
|
||||
- Each DP rank has its own EventPublisher
|
||||
- Publishers annotate events with the dp rank
|
||||
- This allows consumers to distinguish events from different DP ranks
|
||||
"""
|
||||
|
||||
def __init__(self, attn_dp_rank: int = 0):
|
||||
self._attn_dp_rank = attn_dp_rank
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
@@ -130,6 +145,7 @@ class ZmqEventPublisher(EventPublisher):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attn_dp_rank: int,
|
||||
endpoint: str = "tcp://*:5557",
|
||||
replay_endpoint: Optional[str] = None,
|
||||
buffer_steps: int = 10_000,
|
||||
@@ -138,6 +154,7 @@ class ZmqEventPublisher(EventPublisher):
|
||||
topic: str = "",
|
||||
) -> None:
|
||||
# Storage
|
||||
super().__init__(attn_dp_rank)
|
||||
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
||||
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
||||
|
||||
@@ -145,8 +162,11 @@ class ZmqEventPublisher(EventPublisher):
|
||||
self._ctx = zmq.Context.instance()
|
||||
self._pub: Optional[zmq.Socket] = None
|
||||
self._replay: Optional[zmq.Socket] = None
|
||||
self._endpoint = endpoint
|
||||
self._replay_endpoint = replay_endpoint
|
||||
self._dp_rank = attn_dp_rank
|
||||
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
|
||||
self._replay_endpoint = self.offset_endpoint_port(
|
||||
replay_endpoint, self._dp_rank
|
||||
)
|
||||
self._hwm = hwm
|
||||
self._socket_setup()
|
||||
|
||||
@@ -168,6 +188,8 @@ class ZmqEventPublisher(EventPublisher):
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
if not self._running:
|
||||
raise RuntimeError("Publisher is closed")
|
||||
if events.attn_dp_rank is None:
|
||||
events.attn_dp_rank = self._dp_rank
|
||||
self._event_queue.put(events)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
@@ -288,6 +310,39 @@ class ZmqEventPublisher(EventPublisher):
|
||||
# receiving payload is (-1, b""")
|
||||
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||
|
||||
@staticmethod
|
||||
def offset_endpoint_port(
|
||||
endpoint: Optional[str], data_parallel_rank: int
|
||||
) -> Optional[str]:
|
||||
"""Helper function to offset the port in an endpoint by
|
||||
the data parallel rank.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint string
|
||||
(e.g., "tcp://*:5557" or "inproc://cache")
|
||||
data_parallel_rank: The data parallel rank to offset by
|
||||
|
||||
Returns:
|
||||
The endpoint with the port offset by data_parallel_rank
|
||||
or suffix appended
|
||||
"""
|
||||
# Do nothing if input is None or data_parallel_rank is 0
|
||||
if not endpoint or data_parallel_rank == 0:
|
||||
return endpoint
|
||||
|
||||
if "inproc" in endpoint:
|
||||
return f"{endpoint}_dp{data_parallel_rank}"
|
||||
if "tcp" in endpoint:
|
||||
if endpoint and ":" in endpoint:
|
||||
# Get everything after the last colon (the port)
|
||||
last_colon_idx = endpoint.rfind(":")
|
||||
base_addr = endpoint[:last_colon_idx]
|
||||
base_port = int(endpoint[last_colon_idx + 1 :])
|
||||
new_port = base_port + data_parallel_rank
|
||||
return f"{base_addr}:{new_port}"
|
||||
return endpoint
|
||||
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
|
||||
|
||||
|
||||
class KVEventsConfig(BaseModel):
|
||||
"""Configuration for KV event publishing."""
|
||||
@@ -342,7 +397,7 @@ class EventPublisherFactory:
|
||||
cls._registry[name] = ctor
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: Optional[str]) -> EventPublisher:
|
||||
def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
|
||||
"""Create publisher from a config mapping."""
|
||||
if not config:
|
||||
return NullEventPublisher()
|
||||
@@ -354,4 +409,4 @@ class EventPublisherFactory:
|
||||
constructor = cls._registry[kind]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
||||
return constructor(**config_dict)
|
||||
return constructor(attn_dp_rank=attn_dp_rank, **config_dict)
|
||||
|
||||
@@ -571,7 +571,9 @@ class Scheduler(
|
||||
|
||||
def init_kv_events(self, kv_events_config: Optional[str]):
|
||||
if self.enable_kv_cache_events:
|
||||
self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
|
||||
self.kv_event_publisher = EventPublisherFactory.create(
|
||||
kv_events_config, self.attn_dp_rank
|
||||
)
|
||||
|
||||
def init_disaggregation(self):
|
||||
self.transfer_backend = TransferBackend(
|
||||
@@ -1988,7 +1990,7 @@ class Scheduler(
|
||||
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
|
||||
for k, v in server_args_dict.items():
|
||||
global_server_args_dict[k] = v
|
||||
logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
|
||||
logger.info(f"Global server args updated! {global_server_args_dict=}")
|
||||
return SetInternalStateReqOutput(
|
||||
updated=True,
|
||||
server_args=global_server_args_dict,
|
||||
|
||||
@@ -48,6 +48,9 @@ class TestKvEvents(CustomTestCase):
|
||||
32,
|
||||
"--cuda-graph-max-bs",
|
||||
2,
|
||||
"--enable-dp-attention",
|
||||
"--dp-size",
|
||||
1,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -233,7 +236,6 @@ class TestKvEvents(CustomTestCase):
|
||||
_, seq_bytes, payload = sub.recv_multipart()
|
||||
event_batch = decoder.decode(payload)
|
||||
for event in event_batch.events:
|
||||
print(f" - {event}")
|
||||
events.append(event)
|
||||
|
||||
for expected in expected_events:
|
||||
@@ -242,6 +244,134 @@ class TestKvEvents(CustomTestCase):
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_kv_events_attn_dp(self):
|
||||
"""Test that kv events are properly tagged with DP rank in attention DP mode"""
|
||||
|
||||
# Launch multiple subscribers for different DP ranks
|
||||
decoder = Decoder(type=KVEventBatch)
|
||||
context = zmq.Context()
|
||||
|
||||
# Subscribe to both DP rank endpoints
|
||||
sub_dp0 = context.socket(zmq.SUB)
|
||||
sub_dp0.connect("tcp://localhost:5557") # DP rank 0
|
||||
topic = "kv-events"
|
||||
sub_dp0.setsockopt_string(zmq.SUBSCRIBE, topic)
|
||||
|
||||
sub_dp1 = context.socket(zmq.SUB)
|
||||
sub_dp1.connect("tcp://localhost:5558") # DP rank 1 (offset by rank)
|
||||
sub_dp1.setsockopt_string(zmq.SUBSCRIBE, topic)
|
||||
|
||||
# Launch sglang server with DP attention enabled
|
||||
process = popen_launch_server(
|
||||
"silence09/DeepSeek-R1-Small-2layers",
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--kv-events-config",
|
||||
'{"publisher": "zmq", "topic": "kv-events"}',
|
||||
"--max-total-tokens",
|
||||
64,
|
||||
"--cuda-graph-max-bs",
|
||||
4,
|
||||
"--enable-dp-attention",
|
||||
"--dp-size",
|
||||
2,
|
||||
"--tp-size",
|
||||
2,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
# Make requests to generate events
|
||||
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Send multiple requests to trigger events from both DP ranks
|
||||
for i in range(4):
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/generate",
|
||||
json={
|
||||
"text": f"Request {i}: The capital of country {i} is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Collect events from both DP ranks
|
||||
events_dp0 = []
|
||||
events_dp1 = []
|
||||
start = time.time()
|
||||
max_wait_s = 10
|
||||
min_events_per_rank = 3 # Expect at least a few events from each rank
|
||||
|
||||
while (time.time() - start) < max_wait_s and (
|
||||
len(events_dp0) < min_events_per_rank
|
||||
or len(events_dp1) < min_events_per_rank
|
||||
):
|
||||
# Check DP rank 0
|
||||
if sub_dp0.poll(timeout=100): # 100ms timeout
|
||||
_, seq_bytes, payload = sub_dp0.recv_multipart()
|
||||
event_batch = decoder.decode(payload)
|
||||
print(
|
||||
f"DP Rank 0 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
|
||||
)
|
||||
self.assertEqual(
|
||||
event_batch.attn_dp_rank,
|
||||
0,
|
||||
"DP rank 0 events should have attn_dp_rank=0",
|
||||
)
|
||||
for event in event_batch.events:
|
||||
print(f" DP0 - {event}")
|
||||
events_dp0.append(event)
|
||||
|
||||
# Check DP rank 1
|
||||
if sub_dp1.poll(timeout=100): # 100ms timeout
|
||||
_, seq_bytes, payload = sub_dp1.recv_multipart()
|
||||
event_batch = decoder.decode(payload)
|
||||
print(
|
||||
f"DP Rank 1 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
|
||||
)
|
||||
self.assertEqual(
|
||||
event_batch.attn_dp_rank,
|
||||
1,
|
||||
"DP rank 1 events should have attn_dp_rank=1",
|
||||
)
|
||||
for event in event_batch.events:
|
||||
print(f" DP1 - {event}")
|
||||
events_dp1.append(event)
|
||||
|
||||
# Verify we got events from both DP ranks
|
||||
print(f"Collected {len(events_dp0)} events from DP rank 0")
|
||||
print(f"Collected {len(events_dp1)} events from DP rank 1")
|
||||
|
||||
self.assertGreaterEqual(
|
||||
len(events_dp0),
|
||||
min_events_per_rank,
|
||||
f"Expected at least {min_events_per_rank} events from DP rank 0",
|
||||
)
|
||||
self.assertGreaterEqual(
|
||||
len(events_dp1),
|
||||
min_events_per_rank,
|
||||
f"Expected at least {min_events_per_rank} events from DP rank 1",
|
||||
)
|
||||
|
||||
# Verify event types are as expected
|
||||
for events in [events_dp0, events_dp1]:
|
||||
for event in events:
|
||||
self.assertIsInstance(
|
||||
event,
|
||||
(BlockStored, BlockRemoved, AllBlocksCleared),
|
||||
f"Event should be a KV cache event, got {type(event)}",
|
||||
)
|
||||
|
||||
finally:
|
||||
sub_dp0.close()
|
||||
sub_dp1.close()
|
||||
context.term()
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user