feat: add dp-rank to KV events (#6852)

This commit is contained in:
ishandhanani
2025-06-04 15:29:34 -07:00
committed by GitHub
parent 3f1e433903
commit f0f84975f4
3 changed files with 195 additions and 8 deletions

View File

@@ -48,6 +48,9 @@ class TestKvEvents(CustomTestCase):
32,
"--cuda-graph-max-bs",
2,
"--enable-dp-attention",
"--dp-size",
1,
],
)
@@ -233,7 +236,6 @@ class TestKvEvents(CustomTestCase):
_, seq_bytes, payload = sub.recv_multipart()
event_batch = decoder.decode(payload)
for event in event_batch.events:
print(f" - {event}")
events.append(event)
for expected in expected_events:
@@ -242,6 +244,134 @@ class TestKvEvents(CustomTestCase):
finally:
kill_process_tree(process.pid)
def test_kv_events_attn_dp(self):
    """Test that KV events are properly tagged with the DP rank in attention-DP mode.

    Launches a server with ``--enable-dp-attention --dp-size 2``, subscribes one
    ZMQ SUB socket per DP rank (the publisher offsets its base port 5557 by the
    rank), drives a few ``/generate`` requests, and asserts that every event
    batch delivered on each endpoint carries the matching ``attn_dp_rank``.
    """
    decoder = Decoder(type=KVEventBatch)
    context = zmq.Context()

    topic = "kv-events"
    # One subscriber per DP rank; ports are base (5557) + rank.
    sub_dp0 = context.socket(zmq.SUB)
    sub_dp0.connect("tcp://localhost:5557")  # DP rank 0
    sub_dp0.setsockopt_string(zmq.SUBSCRIBE, topic)

    sub_dp1 = context.socket(zmq.SUB)
    sub_dp1.connect("tcp://localhost:5558")  # DP rank 1 (offset by rank)
    sub_dp1.setsockopt_string(zmq.SUBSCRIBE, topic)

    def drain(sub, expected_rank, sink):
        # Pull at most one event batch from `sub` (100 ms poll), verify its
        # rank tag, and append its events to `sink`.
        if not sub.poll(timeout=100):
            return
        _, _seq_bytes, payload = sub.recv_multipart()
        event_batch = decoder.decode(payload)
        print(
            f"DP Rank {expected_rank} - EventBatch: "
            f"ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
        )
        self.assertEqual(
            event_batch.attn_dp_rank,
            expected_rank,
            f"DP rank {expected_rank} events should have attn_dp_rank={expected_rank}",
        )
        for event in event_batch.events:
            print(f"  DP{expected_rank} - {event}")
            sink.append(event)

    process = None  # set once the server launches; checked in `finally`
    try:
        # Launch sglang server with DP attention enabled (DP=2, TP=2).
        # NOTE: launching inside the try so the ZMQ sockets/context are
        # released even when the launch itself fails.
        process = popen_launch_server(
            "silence09/DeepSeek-R1-Small-2layers",
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                "--max-total-tokens",
                64,
                "--cuda-graph-max-bs",
                4,
                "--enable-dp-attention",
                "--dp-size",
                2,
                "--tp-size",
                2,
            ],
        )

        # Health check / warm-up request.
        response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
        self.assertEqual(response.status_code, 200)

        # Send several requests so both DP ranks produce events.
        for i in range(4):
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": f"Request {i}: The capital of country {i} is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 16,
                    },
                },
            )
            # Fail fast if a generate request was rejected instead of
            # timing out later while waiting for events.
            self.assertEqual(response.status_code, 200)

        # Collect events from both DP ranks until each has enough or we
        # run out of time.
        events_dp0 = []
        events_dp1 = []
        start = time.time()
        max_wait_s = 10
        min_events_per_rank = 3  # expect at least a few events from each rank
        while (time.time() - start) < max_wait_s and (
            len(events_dp0) < min_events_per_rank
            or len(events_dp1) < min_events_per_rank
        ):
            drain(sub_dp0, 0, events_dp0)
            drain(sub_dp1, 1, events_dp1)

        # Verify we got events from both DP ranks.
        print(f"Collected {len(events_dp0)} events from DP rank 0")
        print(f"Collected {len(events_dp1)} events from DP rank 1")
        self.assertGreaterEqual(
            len(events_dp0),
            min_events_per_rank,
            f"Expected at least {min_events_per_rank} events from DP rank 0",
        )
        self.assertGreaterEqual(
            len(events_dp1),
            min_events_per_rank,
            f"Expected at least {min_events_per_rank} events from DP rank 1",
        )

        # Every collected event must be one of the known KV cache event types.
        for events in (events_dp0, events_dp1):
            for event in events:
                self.assertIsInstance(
                    event,
                    (BlockStored, BlockRemoved, AllBlocksCleared),
                    f"Event should be a KV cache event, got {type(event)}",
                )
    finally:
        # Release ZMQ resources unconditionally; only kill the server if it
        # actually launched.
        sub_dp0.close()
        sub_dp1.close()
        context.term()
        if process is not None:
            kill_process_tree(process.pid)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()