feat: add dp-rank to KV events (#6852)
This commit is contained in:
@@ -43,6 +43,7 @@ class EventBatch(
|
|||||||
):
|
):
|
||||||
ts: float
|
ts: float
|
||||||
events: list[Any]
|
events: list[Any]
|
||||||
|
attn_dp_rank: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class KVCacheEvent(
|
class KVCacheEvent(
|
||||||
@@ -76,7 +77,21 @@ class KVEventBatch(EventBatch):
|
|||||||
|
|
||||||
|
|
||||||
class EventPublisher(ABC):
|
class EventPublisher(ABC):
|
||||||
"""Lightweight publisher for EventBatch batches."""
|
"""
|
||||||
|
Lightweight publisher for EventBatch batches with
|
||||||
|
support for DP attention.
|
||||||
|
|
||||||
|
In DP attention - each rank has its own Scheduler and
|
||||||
|
KV cache instance in order to avoid duplicate events
|
||||||
|
and ensure proper event attribution. In our implementation
|
||||||
|
|
||||||
|
- Each DP rank has its own EventPublisher
|
||||||
|
- Publishers annotate events with the dp rank
|
||||||
|
- This allows consumers to distinguish events from different DP ranks
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, attn_dp_rank: int = 0):
|
||||||
|
self._attn_dp_rank = attn_dp_rank
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def publish(self, events: EventBatch) -> None:
|
def publish(self, events: EventBatch) -> None:
|
||||||
@@ -130,6 +145,7 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
attn_dp_rank: int,
|
||||||
endpoint: str = "tcp://*:5557",
|
endpoint: str = "tcp://*:5557",
|
||||||
replay_endpoint: Optional[str] = None,
|
replay_endpoint: Optional[str] = None,
|
||||||
buffer_steps: int = 10_000,
|
buffer_steps: int = 10_000,
|
||||||
@@ -138,6 +154,7 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
topic: str = "",
|
topic: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
# Storage
|
# Storage
|
||||||
|
super().__init__(attn_dp_rank)
|
||||||
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
||||||
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
||||||
|
|
||||||
@@ -145,8 +162,11 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
self._ctx = zmq.Context.instance()
|
self._ctx = zmq.Context.instance()
|
||||||
self._pub: Optional[zmq.Socket] = None
|
self._pub: Optional[zmq.Socket] = None
|
||||||
self._replay: Optional[zmq.Socket] = None
|
self._replay: Optional[zmq.Socket] = None
|
||||||
self._endpoint = endpoint
|
self._dp_rank = attn_dp_rank
|
||||||
self._replay_endpoint = replay_endpoint
|
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
|
||||||
|
self._replay_endpoint = self.offset_endpoint_port(
|
||||||
|
replay_endpoint, self._dp_rank
|
||||||
|
)
|
||||||
self._hwm = hwm
|
self._hwm = hwm
|
||||||
self._socket_setup()
|
self._socket_setup()
|
||||||
|
|
||||||
@@ -168,6 +188,8 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
def publish(self, events: EventBatch) -> None:
|
def publish(self, events: EventBatch) -> None:
|
||||||
if not self._running:
|
if not self._running:
|
||||||
raise RuntimeError("Publisher is closed")
|
raise RuntimeError("Publisher is closed")
|
||||||
|
if events.attn_dp_rank is None:
|
||||||
|
events.attn_dp_rank = self._dp_rank
|
||||||
self._event_queue.put(events)
|
self._event_queue.put(events)
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
@@ -288,6 +310,39 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
# receiving payload is (-1, b"")
|
# receiving payload is (-1, b"")
|
||||||
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||||
|
|
||||||
|
@staticmethod
def offset_endpoint_port(
    endpoint: Optional[str], data_parallel_rank: int
) -> Optional[str]:
    """Derive a rank-specific endpoint from a base endpoint.

    The transport is taken from the scheme prefix of the endpoint, not from
    a substring search, so a TCP host name that happens to contain
    "inproc" (e.g. "tcp://inproc-host:5557") is still treated as TCP.

    Args:
        endpoint: The base endpoint string
            (e.g., "tcp://*:5557" or "inproc://cache"), or None.
        data_parallel_rank: The data parallel rank to offset by.

    Returns:
        - ``endpoint`` unchanged when it is None/empty or the rank is 0.
        - For "inproc://" endpoints, the endpoint with ``_dp<rank>`` appended.
        - For "tcp://" endpoints, the endpoint with its port increased by
          ``data_parallel_rank``.

    Raises:
        ValueError: If the endpoint uses neither the inproc nor the tcp
            transport, or if a tcp endpoint does not end in a numeric port.
    """
    # Do nothing if input is None/empty or data_parallel_rank is 0.
    if not endpoint or data_parallel_rank == 0:
        return endpoint

    if endpoint.startswith("inproc://"):
        # inproc names have no port; make the name itself unique per rank.
        return f"{endpoint}_dp{data_parallel_rank}"

    if endpoint.startswith("tcp://"):
        # Everything after the last colon is the port.
        base_addr, _, port_text = endpoint.rpartition(":")
        try:
            base_port = int(port_text)
        except ValueError as exc:
            raise ValueError(
                f"Invalid tcp endpoint {endpoint!r}: expected a numeric port "
                "after the last ':'"
            ) from exc
        return f"{base_addr}:{base_port + data_parallel_rank}"

    raise ValueError(
        f"Invalid endpoint {endpoint!r}: must start with 'inproc://' or 'tcp://'"
    )
|
||||||
|
|
||||||
|
|
||||||
class KVEventsConfig(BaseModel):
|
class KVEventsConfig(BaseModel):
|
||||||
"""Configuration for KV event publishing."""
|
"""Configuration for KV event publishing."""
|
||||||
@@ -342,7 +397,7 @@ class EventPublisherFactory:
|
|||||||
cls._registry[name] = ctor
|
cls._registry[name] = ctor
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, config: Optional[str]) -> EventPublisher:
|
def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
|
||||||
"""Create publisher from a config mapping."""
|
"""Create publisher from a config mapping."""
|
||||||
if not config:
|
if not config:
|
||||||
return NullEventPublisher()
|
return NullEventPublisher()
|
||||||
@@ -354,4 +409,4 @@ class EventPublisherFactory:
|
|||||||
constructor = cls._registry[kind]
|
constructor = cls._registry[kind]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
||||||
return constructor(**config_dict)
|
return constructor(attn_dp_rank=attn_dp_rank, **config_dict)
|
||||||
|
|||||||
@@ -571,7 +571,9 @@ class Scheduler(
|
|||||||
|
|
||||||
def init_kv_events(self, kv_events_config: Optional[str]):
|
def init_kv_events(self, kv_events_config: Optional[str]):
|
||||||
if self.enable_kv_cache_events:
|
if self.enable_kv_cache_events:
|
||||||
self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
|
self.kv_event_publisher = EventPublisherFactory.create(
|
||||||
|
kv_events_config, self.attn_dp_rank
|
||||||
|
)
|
||||||
|
|
||||||
def init_disaggregation(self):
|
def init_disaggregation(self):
|
||||||
self.transfer_backend = TransferBackend(
|
self.transfer_backend = TransferBackend(
|
||||||
@@ -1988,7 +1990,7 @@ class Scheduler(
|
|||||||
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
|
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
|
||||||
for k, v in server_args_dict.items():
|
for k, v in server_args_dict.items():
|
||||||
global_server_args_dict[k] = v
|
global_server_args_dict[k] = v
|
||||||
logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
|
logger.info(f"Global server args updated! {global_server_args_dict=}")
|
||||||
return SetInternalStateReqOutput(
|
return SetInternalStateReqOutput(
|
||||||
updated=True,
|
updated=True,
|
||||||
server_args=global_server_args_dict,
|
server_args=global_server_args_dict,
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ class TestKvEvents(CustomTestCase):
|
|||||||
32,
|
32,
|
||||||
"--cuda-graph-max-bs",
|
"--cuda-graph-max-bs",
|
||||||
2,
|
2,
|
||||||
|
"--enable-dp-attention",
|
||||||
|
"--dp-size",
|
||||||
|
1,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -233,7 +236,6 @@ class TestKvEvents(CustomTestCase):
|
|||||||
_, seq_bytes, payload = sub.recv_multipart()
|
_, seq_bytes, payload = sub.recv_multipart()
|
||||||
event_batch = decoder.decode(payload)
|
event_batch = decoder.decode(payload)
|
||||||
for event in event_batch.events:
|
for event in event_batch.events:
|
||||||
print(f" - {event}")
|
|
||||||
events.append(event)
|
events.append(event)
|
||||||
|
|
||||||
for expected in expected_events:
|
for expected in expected_events:
|
||||||
@@ -242,6 +244,134 @@ class TestKvEvents(CustomTestCase):
|
|||||||
finally:
|
finally:
|
||||||
kill_process_tree(process.pid)
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
def test_kv_events_attn_dp(self):
    """Test that kv events are properly tagged with DP rank in attention DP mode.

    One SUB socket is connected per DP rank (port 5557 for rank 0 and
    5558 for rank 1), a server is launched with ``--enable-dp-attention``,
    ``--dp-size 2`` and ``--tp-size 2``, and four generate requests are sent.
    The test then polls both sockets for up to 10 seconds and checks that:

    - every batch received on a rank's socket carries that rank in
      ``attn_dp_rank``;
    - each rank delivered at least ``min_events_per_rank`` events;
    - every event is a BlockStored, BlockRemoved or AllBlocksCleared.
    """
    topic = "kv-events"
    base_port = 5557  # rank r is subscribed on base_port + r
    decoder = Decoder(type=KVEventBatch)
    context = zmq.Context()

    # rank -> (SUB socket, events collected from that rank)
    subscribers = {}
    process = None

    try:
        # Subscribe to both DP rank endpoints. Each socket is registered
        # before it is configured so the cleanup below always closes it.
        for rank in (0, 1):
            sub = context.socket(zmq.SUB)
            subscribers[rank] = (sub, [])
            sub.connect(f"tcp://localhost:{base_port + rank}")
            sub.setsockopt_string(zmq.SUBSCRIBE, topic)

        # Launch sglang server with DP attention enabled. This happens
        # inside the try block so a failed launch still releases the
        # sockets and the ZMQ context.
        process = popen_launch_server(
            "silence09/DeepSeek-R1-Small-2layers",
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                "--max-total-tokens",
                64,
                "--cuda-graph-max-bs",
                4,
                "--enable-dp-attention",
                "--dp-size",
                2,
                "--tp-size",
                2,
            ],
        )

        # Make requests to generate events
        response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
        self.assertEqual(response.status_code, 200)

        # Send multiple requests to trigger events from both DP ranks
        for i in range(4):
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": f"Request {i}: The capital of country {i} is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 16,
                    },
                },
            )
            # A failed generate call would otherwise only show up later as
            # a confusing "not enough events" failure.
            self.assertEqual(response.status_code, 200)

        # Collect events from both DP ranks
        start = time.time()
        max_wait_s = 10
        min_events_per_rank = 3  # Expect at least a few events from each rank

        while (time.time() - start) < max_wait_s and any(
            len(collected) < min_events_per_rank
            for _, collected in subscribers.values()
        ):
            for rank, (sub, collected) in subscribers.items():
                if not sub.poll(timeout=100):  # 100ms timeout
                    continue
                _, _, payload = sub.recv_multipart()
                event_batch = decoder.decode(payload)
                print(
                    f"DP Rank {rank} - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
                )
                self.assertEqual(
                    event_batch.attn_dp_rank,
                    rank,
                    f"DP rank {rank} events should have attn_dp_rank={rank}",
                )
                for event in event_batch.events:
                    print(f" DP{rank} - {event}")
                    collected.append(event)

        # Verify we got events from both DP ranks
        for rank, (_, collected) in subscribers.items():
            print(f"Collected {len(collected)} events from DP rank {rank}")

        for rank, (_, collected) in subscribers.items():
            self.assertGreaterEqual(
                len(collected),
                min_events_per_rank,
                f"Expected at least {min_events_per_rank} events from DP rank {rank}",
            )

        # Verify event types are as expected
        for _, collected in subscribers.values():
            for event in collected:
                self.assertIsInstance(
                    event,
                    (BlockStored, BlockRemoved, AllBlocksCleared),
                    f"Event should be a KV cache event, got {type(event)}",
                )

    finally:
        try:
            # Stop the server first so it is torn down even if socket
            # cleanup below misbehaves.
            if process is not None:
                kill_process_tree(process.pid)
        finally:
            for sub, _ in subscribers.values():
                sub.close()
            context.term()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user