[Misc] gen kv events in ascendconnector (#6593)

### What this PR does / why we need it?
Refer to https://github.com/vllm-project/vllm-ascend/issues/6391.
This PR adapts the Ascend connector to the complete KV event publishing flow in vLLM:
* `kv_connector_model_runner_mixin` invokes the KV connector's
`get_kv_connector_kv_cache_events` function to collect KV events.
* In `scheduler.py`, the `update_from_output` function invokes
`_update_from_kv_xfer_finished`, which in turn invokes
`connector.update_connector_output` to collect the KV events from all
KV workers. The scheduler then invokes the `connector.take_events` API to
collect all KV events and adds them to the events that come from
`kv_cache_manager`.

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
You can add the `--kv-events-config` parameter to the `vllm serve` command
to enable this feature.

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: yejj710 <abyss1999@163.com>
Co-authored-by: fems14 <1804143737@qq.com>
This commit is contained in:
yejj
2026-02-12 11:01:09 +08:00
committed by GitHub
parent 7221045777
commit 8b23554741
5 changed files with 171 additions and 5 deletions

View File

@@ -1,9 +1,15 @@
import threading import threading
from collections.abc import Iterable
from typing import Any from typing import Any
import torch import torch
import zmq import zmq
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
KVCacheEvent,
KVConnectorKVEvents,
KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import logger from vllm.logger import logger
@@ -12,6 +18,7 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder from vllm.v1.serial_utils import MsgpackDecoder
@@ -22,6 +29,40 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler imp
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
class AscendStoreKVEvents(KVConnectorKVEvents):
    """KV cache event container used by the Ascend store connector.

    Every operation is forwarded to an internal ``KVEventAggregator`` that is
    created with the given number of workers.
    """

    def __init__(self, num_workers: int) -> None:
        self._aggregator = KVEventAggregator(num_workers)

    def __repr__(self) -> str:
        return f"<AscendStoreKVEvents events={self.get_all_events()}>"

    # ---- read accessors -------------------------------------------------

    def get_all_events(self) -> list[KVCacheEvent]:
        """Return whatever the aggregator reports as all held events."""
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        """Return the worker count tracked by the aggregator."""
        return self._aggregator.get_number_of_workers()

    # ---- mutators -------------------------------------------------------

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Hand ``events`` over to the aggregator."""
        self._aggregator.add_events(events)

    def increment_workers(self, count: int = 1) -> None:
        """Raise the aggregator's worker count by ``count``."""
        self._aggregator.increment_workers(count)

    def aggregate(self) -> "AscendStoreKVEvents":
        """Aggregate KV events and retain only the common events.

        The aggregator's common events are fetched, everything is cleared,
        the common events are re-added and the worker count is reset.

        Returns:
            This instance, so calls can be chained.
        """
        aggregator = self._aggregator
        shared = aggregator.get_common_events()
        aggregator.clear_events()
        aggregator.add_events(shared)
        aggregator.reset_workers()
        return self

    def clear_events(self) -> None:
        """Drop all held events and reset the worker count."""
        aggregator = self._aggregator
        aggregator.clear_events()
        aggregator.reset_workers()
class AscendStoreConnector(KVConnectorBase_V1): class AscendStoreConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
@@ -40,6 +81,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
) )
self.kv_caches: dict[str, torch.Tensor] = {} self.kv_caches: dict[str, torch.Tensor] = {}
self._kv_cache_events: AscendStoreKVEvents | None = None
self.sended_but_unfinished_reqs: set[str] = set() self.sended_but_unfinished_reqs: set[str] = set()
@@ -82,6 +124,39 @@ class AscendStoreConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids) return self.connector_scheduler.request_finished(request, block_ids)
def update_connector_output(self, connector_output: KVConnectorOutput):
    """Fold the KV cache events reported by worker-side connectors into this connector.

    Args:
        connector_output (KVConnectorOutput): the worker-side connectors output.
    """
    incoming = connector_output.kv_cache_events
    # Ignore payloads that are missing/falsy or not produced by this connector.
    if not (incoming and isinstance(incoming, AscendStoreKVEvents)):
        return

    accumulated = self._kv_cache_events
    if accumulated is None:
        # Nothing is pending yet: the incoming object becomes the accumulator.
        self._kv_cache_events = incoming
        return

    # Merge both the events and the number of workers that contributed them.
    accumulated.add_events(incoming.get_all_events())
    accumulated.increment_workers(incoming.get_number_of_workers())
def take_events(self) -> Iterable["KVCacheEvent"]:
    """Take the KV cache events from the connector.

    The pending events are aggregated, detached from the connector and
    cleared *before* the first event is yielded. This way the connector's
    state is reset even when the caller stops iterating early or closes the
    generator; previously the cleanup sat after the ``yield from`` and was
    skipped in that case, leaving already-delivered events pending.

    Being a generator, nothing happens until iteration starts.

    Yields:
        New KV cache events since the last call.
    """
    pending = self._kv_cache_events
    if pending is None:
        return

    # Detach first so events arriving later start a fresh accumulator.
    self._kv_cache_events = None

    pending.aggregate()
    # Snapshot the events so that clearing the container below cannot
    # affect what is yielded.
    events = list(pending.get_all_events())
    pending.clear_events()

    yield from events
############################################################ ############################################################
# Worker Side Methods # Worker Side Methods
############################################################ ############################################################
@@ -127,6 +202,18 @@ class AscendStoreConnector(KVConnectorBase_V1):
) )
return done_sending, done_recving return done_sending, done_recving
def get_kv_connector_kv_cache_events(self) -> AscendStoreKVEvents | None:
    """Wrap the KV events reported by the connector worker, if any.

    Returns:
        An ``AscendStoreKVEvents`` created for a single worker and holding
        the events returned by ``connector_worker.get_kv_events()``, or
        ``None`` when that call returns nothing.
    """
    collected = self.connector_worker.get_kv_events()
    if collected:
        wrapped = AscendStoreKVEvents(num_workers=1)
        wrapped.add_events(collected)
        return wrapped
    return None
class LookupKeyServer: class LookupKeyServer:
def __init__( def __init__(

View File

@@ -221,7 +221,6 @@ class RequestTracker:
# Request id # Request id
req_id: str req_id: str
# The token ids that has been scheduled so far
token_len: int token_len: int
# The block ids that has been allocated so far # The block ids that has been allocated so far
@@ -233,6 +232,10 @@ class RequestTracker:
# The number of tokens that have been saved # The number of tokens that have been saved
num_saved_tokens: int = 0 num_saved_tokens: int = 0
# The token ids that have been scheduled so far
# NOTE: This field is only used when kv-event is enabled
token_ids: list[int] | None = None
@staticmethod @staticmethod
def from_new_request( def from_new_request(
new_request: "NewRequestData", new_request: "NewRequestData",
@@ -256,6 +259,7 @@ class RequestTracker:
return RequestTracker( return RequestTracker(
req_id=new_request.req_id, req_id=new_request.req_id,
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(),
token_len=num_tokens_to_compute, token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids, allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0, num_saved_tokens=0,
@@ -268,7 +272,6 @@ class RequestTracker:
"""Update the request tracker when a running request is """Update the request tracker when a running request is
scheduled again scheduled again
""" """
if len(new_block_ids) == 0: if len(new_block_ids) == 0:
new_block_ids = [] new_block_ids = []
elif isinstance(new_block_ids, tuple): elif isinstance(new_block_ids, tuple):
@@ -284,7 +287,7 @@ class RequestTracker:
class ReqMeta: class ReqMeta:
# Request id # Request id
req_id: str req_id: str
# Request tokens # Number of tokens in this chunk
token_len_chunk: int token_len_chunk: int
block_ids: list[int] block_ids: list[int]
@@ -299,6 +302,11 @@ class ReqMeta:
current_event: torch.npu.Event | None = None current_event: torch.npu.Event | None = None
# The following parameters are only used for kv event generation
# TODO: add lora_request, which is needed to generate lora_id/lora_name in the kv event
token_ids: list[int] | None = None
original_block_size: int | None = None
@staticmethod @staticmethod
def from_request_tracker( def from_request_tracker(
tracker: RequestTracker, tracker: RequestTracker,
@@ -308,15 +316,18 @@ class ReqMeta:
block_hashes: list[BlockHash] | None = None, block_hashes: list[BlockHash] | None = None,
is_last_chunk: bool | None = None, is_last_chunk: bool | None = None,
discard_partial_chunks: bool = True, discard_partial_chunks: bool = True,
original_block_size: int | None = None,
) -> Optional["ReqMeta"]: ) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker. """Create the request metadata from a request tracker.
Args: Args:
tracker (RequestTracker): the request tracker. tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM. block_size (int): the block size in vLLM scheduler and AscendConnector.
If context parallelism is enabled, block_size = block_size * pcp_size * dcp_size.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading. load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation. skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks. discard_partial_chunks (bool): whether to discard partial chunks.
original_block_size (int | None): the block size in vLLM worker. This is only used for kv event generation.
Returns: Returns:
the request metadata if we need to perform load/save the request metadata if we need to perform load/save
@@ -342,6 +353,11 @@ class ReqMeta:
if not skip_save: if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save tracker.num_saved_tokens = num_tokens_to_save
# Get the token ids for kv event generation in kv_transfer
token_ids = None
if tracker.token_ids:
token_ids = tracker.token_ids
# # For load operation: check whether the request is scheduled to load # # For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load: if load_spec is not None and load_spec.can_load:
logger.debug( logger.debug(
@@ -361,6 +377,8 @@ class ReqMeta:
load_spec=load_spec, load_spec=load_spec,
block_hashes=block_hashes, block_hashes=block_hashes,
is_last_chunk=is_last_chunk, is_last_chunk=is_last_chunk,
token_ids=token_ids,
original_block_size=original_block_size,
) )

View File

@@ -5,7 +5,9 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any
import torch import torch
from vllm.distributed.kv_events import BlockStored
from vllm.logger import logger from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import maybe_convert_block_hash
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
@@ -41,6 +43,8 @@ class KVTransferThread(threading.Thread):
# TODO(jianzs): make this configurable # TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32) self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set() self.finished_requests: set[str] = set()
self.kv_event_lock = threading.Lock()
self.kv_events: list[BlockStored] = []
def add_request( def add_request(
self, self,
@@ -101,6 +105,16 @@ class KVTransferThread(threading.Thread):
return 0 return 0
return len(keys) return len(keys)
def update_kv_event(self, event: list[BlockStored]):
    """Append a batch of stored-block events to ``kv_events``.

    Args:
        event: the events to append; despite the singular name this is a list.
    """
    # The append happens while holding ``kv_event_lock``.
    with self.kv_event_lock:
        pending = self.kv_events
        pending.extend(event)
def get_kv_events(self) -> list[BlockStored]:
    """Drain ``kv_events``.

    Returns:
        A new list with the events that were pending; ``kv_events`` itself
        is emptied in place.
    """
    # Copy and empty the list while holding ``kv_event_lock``.
    with self.kv_event_lock:
        pending = self.kv_events
        drained = pending[:]
        del pending[:]
    return drained
class KVCacheStoreSendingThread(KVTransferThread): class KVCacheStoreSendingThread(KVTransferThread):
def __init__( def __init__(
@@ -113,6 +127,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
put_step: int, put_step: int,
kv_role: str, kv_role: str,
ready_event: threading.Event, ready_event: threading.Event,
enable_kv_event: bool = False,
): ):
super().__init__( super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread" m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
@@ -120,6 +135,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
self.put_step = put_step self.put_step = put_step
self.kv_role = kv_role self.kv_role = kv_role
self.stored_requests = defaultdict[str, int](int) self.stored_requests = defaultdict[str, int](int)
self.enable_kv_event = enable_kv_event
def add_stored_request(self, req_id: str): def add_stored_request(self, req_id: str):
with self.done_task_lock: with self.done_task_lock:
@@ -188,11 +204,30 @@ class KVCacheStoreSendingThread(KVTransferThread):
""" """
addrs = [] addrs = []
sizes = [] sizes = []
stored_events: list[BlockStored] = []
prev_key = None
new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
for index, start in enumerate(starts): for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids) addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
addrs.append(addr) addrs.append(addr)
sizes.append(size) sizes.append(size)
# Create KV event
if self.enable_kv_event:
token_ids = req_meta.token_ids[start : ends[index]] if req_meta.token_ids is not None else None
stored_event = BlockStored(
block_hashes=[new_block_hashes[index]],
parent_block_hash=prev_key,
token_ids=token_ids,
block_size=req_meta.original_block_size,
lora_id=None,
medium="cpu",
lora_name=None,
)
stored_events.append(stored_event)
prev_key = new_block_hashes[index]
logger.debug(f"Added kv cache event '{stored_event}' to kv cache events queue")
if self.kv_role == "kv_consumer": if self.kv_role == "kv_consumer":
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes) keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)
@@ -200,6 +235,10 @@ class KVCacheStoreSendingThread(KVTransferThread):
current_event.synchronize() current_event.synchronize()
self.m_store.put(keys, addrs, sizes) self.m_store.put(keys, addrs, sizes)
# TODO Query specific replica info to update the event
if self.enable_kv_event and stored_events is not None:
self.update_kv_event(stored_events)
self.dec_stored_request(req_id) self.dec_stored_request(req_id)
self.request_queue.task_done() self.request_queue.task_done()
@@ -253,12 +292,14 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
put_step: int, put_step: int,
ready_event: threading.Event, ready_event: threading.Event,
num_layers: int, num_layers: int,
enable_kv_event: bool = False,
): ):
super().__init__( super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread" m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
) )
self.final_layer_id = num_layers - 1 self.final_layer_id = num_layers - 1
self.put_step = put_step self.put_step = put_step
self.enable_kv_event = enable_kv_event
def add_request( # type: ignore[override] def add_request( # type: ignore[override]
self, req_meta: ReqMeta self, req_meta: ReqMeta

View File

@@ -37,6 +37,7 @@ class KVPoolScheduler:
self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1) self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1) self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
self.original_block_size = vllm_config.cache_config.block_size
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1: if self.pcp_size > 1:
self._block_size *= self.pcp_size self._block_size *= self.pcp_size
@@ -183,6 +184,7 @@ class KVPoolScheduler:
token_len=num_tokens_to_compute, token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids, allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0, num_saved_tokens=0,
token_ids=request.prompt_token_ids[:num_tokens_to_compute].copy(),
) )
self._request_trackers[request.req_id] = request_tracker self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ( last_chunk_tokens_num = (
@@ -199,6 +201,7 @@ class KVPoolScheduler:
block_hashes=request_real.block_hashes, block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks, discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
) )
if req_meta is not None: if req_meta is not None:
meta.add_request(req_meta) meta.add_request(req_meta)
@@ -227,6 +230,7 @@ class KVPoolScheduler:
token_len=num_tokens_to_compute, token_len=num_tokens_to_compute,
allocated_block_ids=new_block_ids, allocated_block_ids=new_block_ids,
num_saved_tokens=0, num_saved_tokens=0,
token_ids=request_real.prompt_token_ids[:num_tokens_to_compute].copy(),
) )
self._request_trackers[req_id] = request_tracker self._request_trackers[req_id] = request_tracker
last_chunk_tokens_num = ( last_chunk_tokens_num = (
@@ -242,6 +246,7 @@ class KVPoolScheduler:
block_hashes=request_real.block_hashes, block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks, discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
) )
# decode/chunked request # decode/chunked request
@@ -276,6 +281,7 @@ class KVPoolScheduler:
block_hashes=request.block_hashes, block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks, discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
) )
if req_meta is not None: if req_meta is not None:
meta.add_request(req_meta) meta.add_request(req_meta)
@@ -299,7 +305,6 @@ class KVPoolScheduler:
) )
self._request_trackers[request_id] = request_tracker self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker( req_meta = ReqMeta.from_request_tracker(
request_tracker, request_tracker,
self._block_size, self._block_size,

View File

@@ -12,6 +12,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.distributed.kv_events import BlockStored
from vllm.logger import logger from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.kv_cache_utils import BlockHash
@@ -74,6 +75,7 @@ class KVPoolWorker:
"consumer_is_to_put", False "consumer_is_to_put", False
) )
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake") self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
self.original_block_size = vllm_config.cache_config.block_size
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1: if self.pcp_size > 1:
@@ -146,6 +148,10 @@ class KVPoolWorker:
self.m_store = real_backend( # type: ignore[misc] self.m_store = real_backend( # type: ignore[misc]
parallel_config parallel_config
) )
kv_event_config = vllm_config.kv_events_config
self.enable_kv_events = False
if kv_event_config and kv_event_config.enable_kv_cache_events:
self.enable_kv_events = True
self.kv_send_thread: KVTransferThread | None = None self.kv_send_thread: KVTransferThread | None = None
self.kv_recv_thread: KVTransferThread | None = None self.kv_recv_thread: KVTransferThread | None = None
@@ -209,6 +215,7 @@ class KVPoolWorker:
self.put_step, self.put_step,
ready_event_sending, ready_event_sending,
self.num_layers, self.num_layers,
self.enable_kv_events,
) )
self.kv_send_thread.start() self.kv_send_thread.start()
ready_event = threading.Event() ready_event = threading.Event()
@@ -235,6 +242,7 @@ class KVPoolWorker:
self.put_step, self.put_step,
self.kv_role, self.kv_role,
ready_event_sending, ready_event_sending,
self.enable_kv_events,
) )
self.kv_send_thread.start() self.kv_send_thread.start()
if self.load_async: if self.load_async:
@@ -641,3 +649,10 @@ class KVPoolWorker:
return min(idx for row in arr for idx, val in enumerate(row) if val != 1) return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
except ValueError: except ValueError:
return -1 return -1
def get_kv_events(self) -> list[BlockStored]:
    """Return the KV events gathered by the sending thread.

    Returns:
        The result of ``kv_send_thread.get_kv_events()`` when KV events are
        enabled and a sending thread exists, otherwise an empty list.
    """
    if not self.enable_kv_events or self.kv_send_thread is None:
        return []
    # collect stored kv events from the sending thread
    return self.kv_send_thread.get_kv_events()