feat: add dp-rank to KV events (#6852)
This commit is contained in:
@@ -43,6 +43,7 @@ class EventBatch(
|
||||
):
|
||||
ts: float
|
||||
events: list[Any]
|
||||
attn_dp_rank: Optional[int] = None
|
||||
|
||||
|
||||
class KVCacheEvent(
|
||||
@@ -76,7 +77,21 @@ class KVEventBatch(EventBatch):
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches."""
|
||||
"""
|
||||
Lightweight publisher for EventBatch batches with
|
||||
support for DP attention.
|
||||
|
||||
In DP attention - each rank has its own Scheduler and
|
||||
KV cache instance in order to avoid duplicate events
|
||||
and ensure proper event attribution. In our implementation
|
||||
|
||||
- Each DP rank has its own EventPublisher
|
||||
- Publishers annotate events with the dp rank
|
||||
- This allows consumers to distinguish events from different DP ranks
|
||||
"""
|
||||
|
||||
def __init__(self, attn_dp_rank: int = 0):
|
||||
self._attn_dp_rank = attn_dp_rank
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
@@ -130,6 +145,7 @@ class ZmqEventPublisher(EventPublisher):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attn_dp_rank: int,
|
||||
endpoint: str = "tcp://*:5557",
|
||||
replay_endpoint: Optional[str] = None,
|
||||
buffer_steps: int = 10_000,
|
||||
@@ -138,6 +154,7 @@ class ZmqEventPublisher(EventPublisher):
|
||||
topic: str = "",
|
||||
) -> None:
|
||||
# Storage
|
||||
super().__init__(attn_dp_rank)
|
||||
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
||||
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
||||
|
||||
@@ -145,8 +162,11 @@ class ZmqEventPublisher(EventPublisher):
|
||||
self._ctx = zmq.Context.instance()
|
||||
self._pub: Optional[zmq.Socket] = None
|
||||
self._replay: Optional[zmq.Socket] = None
|
||||
self._endpoint = endpoint
|
||||
self._replay_endpoint = replay_endpoint
|
||||
self._dp_rank = attn_dp_rank
|
||||
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
|
||||
self._replay_endpoint = self.offset_endpoint_port(
|
||||
replay_endpoint, self._dp_rank
|
||||
)
|
||||
self._hwm = hwm
|
||||
self._socket_setup()
|
||||
|
||||
@@ -168,6 +188,8 @@ class ZmqEventPublisher(EventPublisher):
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
if not self._running:
|
||||
raise RuntimeError("Publisher is closed")
|
||||
if events.attn_dp_rank is None:
|
||||
events.attn_dp_rank = self._dp_rank
|
||||
self._event_queue.put(events)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
@@ -288,6 +310,39 @@ class ZmqEventPublisher(EventPublisher):
|
||||
# receiving payload is (-1, b""")
|
||||
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||
|
||||
@staticmethod
|
||||
def offset_endpoint_port(
|
||||
endpoint: Optional[str], data_parallel_rank: int
|
||||
) -> Optional[str]:
|
||||
"""Helper function to offset the port in an endpoint by
|
||||
the data parallel rank.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint string
|
||||
(e.g., "tcp://*:5557" or "inproc://cache")
|
||||
data_parallel_rank: The data parallel rank to offset by
|
||||
|
||||
Returns:
|
||||
The endpoint with the port offset by data_parallel_rank
|
||||
or suffix appended
|
||||
"""
|
||||
# Do nothing if input is None or data_parallel_rank is 0
|
||||
if not endpoint or data_parallel_rank == 0:
|
||||
return endpoint
|
||||
|
||||
if "inproc" in endpoint:
|
||||
return f"{endpoint}_dp{data_parallel_rank}"
|
||||
if "tcp" in endpoint:
|
||||
if endpoint and ":" in endpoint:
|
||||
# Get everything after the last colon (the port)
|
||||
last_colon_idx = endpoint.rfind(":")
|
||||
base_addr = endpoint[:last_colon_idx]
|
||||
base_port = int(endpoint[last_colon_idx + 1 :])
|
||||
new_port = base_port + data_parallel_rank
|
||||
return f"{base_addr}:{new_port}"
|
||||
return endpoint
|
||||
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
|
||||
|
||||
|
||||
class KVEventsConfig(BaseModel):
|
||||
"""Configuration for KV event publishing."""
|
||||
@@ -342,7 +397,7 @@ class EventPublisherFactory:
|
||||
cls._registry[name] = ctor
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: Optional[str]) -> EventPublisher:
|
||||
def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
|
||||
"""Create publisher from a config mapping."""
|
||||
if not config:
|
||||
return NullEventPublisher()
|
||||
@@ -354,4 +409,4 @@ class EventPublisherFactory:
|
||||
constructor = cls._registry[kind]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
||||
return constructor(**config_dict)
|
||||
return constructor(attn_dp_rank=attn_dp_rank, **config_dict)
|
||||
|
||||
@@ -571,7 +571,9 @@ class Scheduler(
|
||||
|
||||
def init_kv_events(self, kv_events_config: Optional[str]):
|
||||
if self.enable_kv_cache_events:
|
||||
self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
|
||||
self.kv_event_publisher = EventPublisherFactory.create(
|
||||
kv_events_config, self.attn_dp_rank
|
||||
)
|
||||
|
||||
def init_disaggregation(self):
|
||||
self.transfer_backend = TransferBackend(
|
||||
@@ -1988,7 +1990,7 @@ class Scheduler(
|
||||
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
|
||||
for k, v in server_args_dict.items():
|
||||
global_server_args_dict[k] = v
|
||||
logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
|
||||
logger.info(f"Global server args updated! {global_server_args_dict=}")
|
||||
return SetInternalStateReqOutput(
|
||||
updated=True,
|
||||
server_args=global_server_args_dict,
|
||||
|
||||
Reference in New Issue
Block a user