feat: add dp-rank to KV events (#6852)

This commit is contained in:
ishandhanani
2025-06-04 15:29:34 -07:00
committed by GitHub
parent 3f1e433903
commit f0f84975f4
3 changed files with 195 additions and 8 deletions

View File

@@ -43,6 +43,7 @@ class EventBatch(
):
ts: float
events: list[Any]
attn_dp_rank: Optional[int] = None
class KVCacheEvent(
@@ -76,7 +77,21 @@ class KVEventBatch(EventBatch):
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches."""
"""
Lightweight publisher for EventBatch batches with
support for DP attention.
In DP attention - each rank has its own Scheduler and
KV cache instance in order to avoid duplicate events
and ensure proper event attribution. In our implementation
- Each DP rank has its own EventPublisher
- Publishers annotate events with the dp rank
- This allows consumers to distinguish events from different DP ranks
"""
def __init__(self, attn_dp_rank: int = 0):
# Remember the attention-DP rank this publisher was created for (defaults to 0).
self._attn_dp_rank = attn_dp_rank
@abstractmethod
def publish(self, events: EventBatch) -> None:
@@ -130,6 +145,7 @@ class ZmqEventPublisher(EventPublisher):
def __init__(
self,
attn_dp_rank: int,
endpoint: str = "tcp://*:5557",
replay_endpoint: Optional[str] = None,
buffer_steps: int = 10_000,
@@ -138,6 +154,7 @@ class ZmqEventPublisher(EventPublisher):
topic: str = "",
) -> None:
# Storage
super().__init__(attn_dp_rank)
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
@@ -145,8 +162,11 @@ class ZmqEventPublisher(EventPublisher):
self._ctx = zmq.Context.instance()
self._pub: Optional[zmq.Socket] = None
self._replay: Optional[zmq.Socket] = None
self._endpoint = endpoint
self._replay_endpoint = replay_endpoint
self._dp_rank = attn_dp_rank
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
self._replay_endpoint = self.offset_endpoint_port(
replay_endpoint, self._dp_rank
)
self._hwm = hwm
self._socket_setup()
@@ -168,6 +188,8 @@ class ZmqEventPublisher(EventPublisher):
def publish(self, events: EventBatch) -> None:
# Hand a batch to self._event_queue; raises RuntimeError once self._running is false.
if not self._running:
raise RuntimeError("Publisher is closed")
# Stamp the batch with this publisher's DP rank only when the caller left it unset,
# so an explicitly provided rank is never overwritten. Note this mutates `events` in place.
if events.attn_dp_rank is None:
events.attn_dp_rank = self._dp_rank
# NOTE(review): the queue is created with a maxsize in __init__, so this put
# presumably blocks while the queue is full — confirm against the Queue type used.
self._event_queue.put(events)
def shutdown(self) -> None:
@@ -288,6 +310,39 @@ class ZmqEventPublisher(EventPublisher):
# receiving payload is (-1, b"")
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
@staticmethod
def offset_endpoint_port(
endpoint: Optional[str], data_parallel_rank: int
) -> Optional[str]:
"""Helper function to offset the port in an endpoint by
the data parallel rank.
Args:
endpoint: The endpoint string
(e.g., "tcp://*:5557" or "inproc://cache")
data_parallel_rank: The data parallel rank to offset by
Returns:
The endpoint with the port offset by data_parallel_rank
or suffix appended
"""
# Do nothing if input is None or data_parallel_rank is 0
if not endpoint or data_parallel_rank == 0:
return endpoint
if "inproc" in endpoint:
return f"{endpoint}_dp{data_parallel_rank}"
if "tcp" in endpoint:
if endpoint and ":" in endpoint:
# Get everything after the last colon (the port)
last_colon_idx = endpoint.rfind(":")
base_addr = endpoint[:last_colon_idx]
base_port = int(endpoint[last_colon_idx + 1 :])
new_port = base_port + data_parallel_rank
return f"{base_addr}:{new_port}"
return endpoint
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
class KVEventsConfig(BaseModel):
"""Configuration for KV event publishing."""
@@ -342,7 +397,7 @@ class EventPublisherFactory:
cls._registry[name] = ctor
@classmethod
def create(cls, config: Optional[str]) -> EventPublisher:
def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
"""Create publisher from a config mapping."""
if not config:
return NullEventPublisher()
@@ -354,4 +409,4 @@ class EventPublisherFactory:
constructor = cls._registry[kind]
except KeyError as exc:
raise ValueError(f"Unknown event publisher '{kind}'") from exc
return constructor(**config_dict)
return constructor(attn_dp_rank=attn_dp_rank, **config_dict)

View File

@@ -571,7 +571,9 @@ class Scheduler(
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
self.kv_event_publisher = EventPublisherFactory.create(
kv_events_config, self.attn_dp_rank
)
def init_disaggregation(self):
self.transfer_backend = TransferBackend(
@@ -1988,7 +1990,7 @@ class Scheduler(
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
for k, v in server_args_dict.items():
global_server_args_dict[k] = v
logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
logger.info(f"Global server args updated! {global_server_args_dict=}")
return SetInternalStateReqOutput(
updated=True,
server_args=global_server_args_dict,