feat: add dp-rank to KV events (#6852)
This commit is contained in:
@@ -48,6 +48,9 @@ class TestKvEvents(CustomTestCase):
|
||||
32,
|
||||
"--cuda-graph-max-bs",
|
||||
2,
|
||||
"--enable-dp-attention",
|
||||
"--dp-size",
|
||||
1,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -233,7 +236,6 @@ class TestKvEvents(CustomTestCase):
|
||||
_, seq_bytes, payload = sub.recv_multipart()
|
||||
event_batch = decoder.decode(payload)
|
||||
for event in event_batch.events:
|
||||
print(f" - {event}")
|
||||
events.append(event)
|
||||
|
||||
for expected in expected_events:
|
||||
@@ -242,6 +244,134 @@ class TestKvEvents(CustomTestCase):
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_kv_events_attn_dp(self):
|
||||
"""Test that kv events are properly tagged with DP rank in attention DP mode"""
|
||||
|
||||
# Launch multiple subscribers for different DP ranks
|
||||
decoder = Decoder(type=KVEventBatch)
|
||||
context = zmq.Context()
|
||||
|
||||
# Subscribe to both DP rank endpoints
|
||||
sub_dp0 = context.socket(zmq.SUB)
|
||||
sub_dp0.connect("tcp://localhost:5557") # DP rank 0
|
||||
topic = "kv-events"
|
||||
sub_dp0.setsockopt_string(zmq.SUBSCRIBE, topic)
|
||||
|
||||
sub_dp1 = context.socket(zmq.SUB)
|
||||
sub_dp1.connect("tcp://localhost:5558") # DP rank 1 (offset by rank)
|
||||
sub_dp1.setsockopt_string(zmq.SUBSCRIBE, topic)
|
||||
|
||||
# Launch sglang server with DP attention enabled
|
||||
process = popen_launch_server(
|
||||
"silence09/DeepSeek-R1-Small-2layers",
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--kv-events-config",
|
||||
'{"publisher": "zmq", "topic": "kv-events"}',
|
||||
"--max-total-tokens",
|
||||
64,
|
||||
"--cuda-graph-max-bs",
|
||||
4,
|
||||
"--enable-dp-attention",
|
||||
"--dp-size",
|
||||
2,
|
||||
"--tp-size",
|
||||
2,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
# Make requests to generate events
|
||||
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Send multiple requests to trigger events from both DP ranks
|
||||
for i in range(4):
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/generate",
|
||||
json={
|
||||
"text": f"Request {i}: The capital of country {i} is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Collect events from both DP ranks
|
||||
events_dp0 = []
|
||||
events_dp1 = []
|
||||
start = time.time()
|
||||
max_wait_s = 10
|
||||
min_events_per_rank = 3 # Expect at least a few events from each rank
|
||||
|
||||
while (time.time() - start) < max_wait_s and (
|
||||
len(events_dp0) < min_events_per_rank
|
||||
or len(events_dp1) < min_events_per_rank
|
||||
):
|
||||
# Check DP rank 0
|
||||
if sub_dp0.poll(timeout=100): # 100ms timeout
|
||||
_, seq_bytes, payload = sub_dp0.recv_multipart()
|
||||
event_batch = decoder.decode(payload)
|
||||
print(
|
||||
f"DP Rank 0 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
|
||||
)
|
||||
self.assertEqual(
|
||||
event_batch.attn_dp_rank,
|
||||
0,
|
||||
"DP rank 0 events should have attn_dp_rank=0",
|
||||
)
|
||||
for event in event_batch.events:
|
||||
print(f" DP0 - {event}")
|
||||
events_dp0.append(event)
|
||||
|
||||
# Check DP rank 1
|
||||
if sub_dp1.poll(timeout=100): # 100ms timeout
|
||||
_, seq_bytes, payload = sub_dp1.recv_multipart()
|
||||
event_batch = decoder.decode(payload)
|
||||
print(
|
||||
f"DP Rank 1 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
|
||||
)
|
||||
self.assertEqual(
|
||||
event_batch.attn_dp_rank,
|
||||
1,
|
||||
"DP rank 1 events should have attn_dp_rank=1",
|
||||
)
|
||||
for event in event_batch.events:
|
||||
print(f" DP1 - {event}")
|
||||
events_dp1.append(event)
|
||||
|
||||
# Verify we got events from both DP ranks
|
||||
print(f"Collected {len(events_dp0)} events from DP rank 0")
|
||||
print(f"Collected {len(events_dp1)} events from DP rank 1")
|
||||
|
||||
self.assertGreaterEqual(
|
||||
len(events_dp0),
|
||||
min_events_per_rank,
|
||||
f"Expected at least {min_events_per_rank} events from DP rank 0",
|
||||
)
|
||||
self.assertGreaterEqual(
|
||||
len(events_dp1),
|
||||
min_events_per_rank,
|
||||
f"Expected at least {min_events_per_rank} events from DP rank 1",
|
||||
)
|
||||
|
||||
# Verify event types are as expected
|
||||
for events in [events_dp0, events_dp1]:
|
||||
for event in events:
|
||||
self.assertIsInstance(
|
||||
event,
|
||||
(BlockStored, BlockRemoved, AllBlocksCleared),
|
||||
f"Event should be a KV cache event, got {type(event)}",
|
||||
)
|
||||
|
||||
finally:
|
||||
sub_dp0.close()
|
||||
sub_dp1.close()
|
||||
context.term()
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user