[Misc] gen kv events in ascendconnector (#6593)

### What this PR does / why we need it?
Refer to https://github.com/vllm-project/vllm-ascend/issues/6391.
This PR adapts the complete event-publishing flow in vLLM (see the sketch after this list):
* `kv_connector_model_runner_mixin` invokes the kv-connector's
`get_kv_connector_kv_cache_events` function to collect KV events on the worker side.
* In `scheduler.py`, the `update_from_output` function invokes
`_update_from_kv_xfer_finished`, which calls
`connector.update_connector_output` to collect KV events from all
KV workers; the scheduler then invokes the `connector.take_events` API to
collect the aggregated KV events and appends them to the events coming from the
`kv_cache_manager`.
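
A minimal sketch of the scheduler-side flow described above, with simplified signatures; the helper names and the publisher attribute are illustrative assumptions, not taken verbatim from vLLM:

```python
# Hedged sketch of the scheduler-side aggregation path (not the actual vLLM code).

def _update_from_kv_xfer_finished(scheduler, kv_connector_output):
    # Fold one worker's KV events into the scheduler-side connector state.
    scheduler.connector.update_connector_output(kv_connector_output)

def publish_kv_events(scheduler):
    # Start with the events produced by the kv_cache_manager (on-device block events) ...
    events = list(scheduler.kv_cache_manager.take_events())
    # ... then append the connector events aggregated across all KV workers.
    events.extend(scheduler.connector.take_events())
    # kv_event_publisher is a hypothetical name for the configured event publisher.
    scheduler.kv_event_publisher.publish(events)
```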

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
You can add the `--kv-events-config` parameter to the `vllm serve` command
to enable this feature.
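
For example, a hypothetical invocation (the model name and ZMQ endpoint are placeholders, and the JSON keys are assumed to match vLLM's kv-events config):

```bash
# Illustrative only: enable KV cache event publishing over ZMQ.
vllm serve <model-name> \
  --kv-events-config '{"enable_kv_cache_events": true, "publisher": "zmq", "endpoint": "tcp://*:5557"}'
```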

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: yejj710 <abyss1999@163.com>
Co-authored-by: fems14 <1804143737@qq.com>
yejj
2026-02-12 11:01:09 +08:00
committed by GitHub
parent 7221045777
commit 8b23554741
5 changed files with 171 additions and 5 deletions


@@ -1,9 +1,15 @@
import threading
from collections.abc import Iterable
from typing import Any
import torch
import zmq
from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
KVCacheEvent,
KVConnectorKVEvents,
KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.forward_context import ForwardContext
from vllm.logger import logger
@@ -12,6 +18,7 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder
@@ -22,6 +29,40 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler imp
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
class AscendStoreKVEvents(KVConnectorKVEvents):
def __init__(self, num_workers: int) -> None:
self._aggregator = KVEventAggregator(num_workers)
def add_events(self, events: list[KVCacheEvent]) -> None:
self._aggregator.add_events(events)
def aggregate(self) -> "AscendStoreKVEvents":
"""
Aggregate KV events and retain only common events.
"""
common_events = self._aggregator.get_common_events()
self._aggregator.clear_events()
self._aggregator.add_events(common_events)
self._aggregator.reset_workers()
return self
def increment_workers(self, count: int = 1) -> None:
self._aggregator.increment_workers(count)
def get_all_events(self) -> list[KVCacheEvent]:
return self._aggregator.get_all_events()
def get_number_of_workers(self) -> int:
return self._aggregator.get_number_of_workers()
def clear_events(self) -> None:
self._aggregator.clear_events()
self._aggregator.reset_workers()
def __repr__(self) -> str:
return f"<AscendStoreKVEvents events={self.get_all_events()}>"
class AscendStoreConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
@@ -40,6 +81,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
)
self.kv_caches: dict[str, torch.Tensor] = {}
self._kv_cache_events: AscendStoreKVEvents | None = None
self.sended_but_unfinished_reqs: set[str] = set()
@@ -82,6 +124,39 @@ class AscendStoreConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side connectors output.
"""
# Get the KV events
kv_cache_events = connector_output.kv_cache_events
if not kv_cache_events or not isinstance(kv_cache_events, AscendStoreKVEvents):
return
if self._kv_cache_events is None:
self._kv_cache_events = kv_cache_events
else:
self._kv_cache_events.add_events(kv_cache_events.get_all_events())
self._kv_cache_events.increment_workers(kv_cache_events.get_number_of_workers())
return
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
if self._kv_cache_events is not None:
self._kv_cache_events.aggregate()
kv_cache_events = self._kv_cache_events.get_all_events()
yield from kv_cache_events
self._kv_cache_events.clear_events()
self._kv_cache_events = None
############################################################
# Worker Side Methods
############################################################
@@ -127,6 +202,18 @@ class AscendStoreConnector(KVConnectorBase_V1):
)
return done_sending, done_recving
def get_kv_connector_kv_cache_events(self) -> AscendStoreKVEvents | None:
"""
Get the KV connector kv cache events collected during the last interval.
"""
events = self.connector_worker.get_kv_events()
if not events:
return None
ascend_store_kv_events = AscendStoreKVEvents(num_workers=1)
ascend_store_kv_events.add_events(events)
return ascend_store_kv_events
class LookupKeyServer:
def __init__(


@@ -221,7 +221,6 @@ class RequestTracker:
# Request id
req_id: str
# The token ids that has been scheduled so far
token_len: int
# The block ids that has been allocated so far
@@ -233,6 +232,10 @@ class RequestTracker:
# The number of tokens that have been saved
num_saved_tokens: int = 0
# The token ids that has been scheduled so far
# NOTE: This field will only be used when you enable kv-event
token_ids: list[int] | None = None
@staticmethod
def from_new_request(
new_request: "NewRequestData",
@@ -256,6 +259,7 @@ class RequestTracker:
return RequestTracker(
req_id=new_request.req_id,
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(),
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
@@ -268,7 +272,6 @@ class RequestTracker:
"""Update the request tracker when a running request is
scheduled again
"""
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
@@ -284,7 +287,7 @@ class RequestTracker:
class ReqMeta:
# Request id
req_id: str
# Request tokens
# Number of tokens in this chunk
token_len_chunk: int
block_ids: list[int]
@@ -299,6 +302,11 @@ class ReqMeta:
current_event: torch.npu.Event | None = None
# The following parameters are only used for kv event generation
# TODO: add lora_request which used for gen lora_id/lora_name in kv event
token_ids: list[int] | None = None
original_block_size: int | None = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
@@ -308,15 +316,18 @@ class ReqMeta:
block_hashes: list[BlockHash] | None = None,
is_last_chunk: bool | None = None,
discard_partial_chunks: bool = True,
original_block_size: int | None = None,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
block_size (int): the block size in vLLM scheduler and AscendConnector.
If context parallelism is enabled, block_size = block_size * pcp_size * dcp_size.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks.
original_block_size (int | None): the block size in vLLM worker. This is only used for kv event generation.
Returns:
the request metadata if we need to perform load/save
@@ -342,6 +353,11 @@ class ReqMeta:
if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save
# Get the token ids for kv event generation in kv_transfer
token_ids = None
if tracker.token_ids:
token_ids = tracker.token_ids
# For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load:
logger.debug(
@@ -361,6 +377,8 @@ class ReqMeta:
load_spec=load_spec,
block_hashes=block_hashes,
is_last_chunk=is_last_chunk,
token_ids=token_ids,
original_block_size=original_block_size,
)


@@ -5,7 +5,9 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Any
import torch
from vllm.distributed.kv_events import BlockStored
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import maybe_convert_block_hash
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
@@ -41,6 +43,8 @@ class KVTransferThread(threading.Thread):
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set()
self.kv_event_lock = threading.Lock()
self.kv_events: list[BlockStored] = []
def add_request(
self,
@@ -101,6 +105,16 @@ class KVTransferThread(threading.Thread):
return 0
return len(keys)
def update_kv_event(self, event: list[BlockStored]):
with self.kv_event_lock:
self.kv_events.extend(event)
def get_kv_events(self) -> list[BlockStored]:
with self.kv_event_lock:
events = self.kv_events.copy()
self.kv_events.clear()
return events
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(
@@ -113,6 +127,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
put_step: int,
kv_role: str,
ready_event: threading.Event,
enable_kv_event: bool = False,
):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
@@ -120,6 +135,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
self.put_step = put_step
self.kv_role = kv_role
self.stored_requests = defaultdict[str, int](int)
self.enable_kv_event = enable_kv_event
def add_stored_request(self, req_id: str):
with self.done_task_lock:
@@ -188,11 +204,30 @@ class KVCacheStoreSendingThread(KVTransferThread):
"""
addrs = []
sizes = []
stored_events: list[BlockStored] = []
prev_key = None
new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
# Create KV event
if self.enable_kv_event:
token_ids = req_meta.token_ids[start : ends[index]] if req_meta.token_ids is not None else None
stored_event = BlockStored(
block_hashes=[new_block_hashes[index]],
parent_block_hash=prev_key,
token_ids=token_ids,
block_size=req_meta.original_block_size,
lora_id=None,
medium="cpu",
lora_name=None,
)
stored_events.append(stored_event)
prev_key = new_block_hashes[index]
logger.debug(f"Added kv cache event '{stored_event}' to kv cache events queue")
if self.kv_role == "kv_consumer":
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)
@@ -200,6 +235,10 @@ class KVCacheStoreSendingThread(KVTransferThread):
current_event.synchronize()
self.m_store.put(keys, addrs, sizes)
# TODO Query specific replica info to update the event
if self.enable_kv_event and stored_events is not None:
self.update_kv_event(stored_events)
self.dec_stored_request(req_id)
self.request_queue.task_done()
@@ -253,12 +292,14 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
put_step: int,
ready_event: threading.Event,
num_layers: int,
enable_kv_event: bool = False,
):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
)
self.final_layer_id = num_layers - 1
self.put_step = put_step
self.enable_kv_event = enable_kv_event
def add_request( # type: ignore[override]
self, req_meta: ReqMeta


@@ -37,6 +37,7 @@ class KVPoolScheduler:
self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
self.original_block_size = vllm_config.cache_config.block_size
self._block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
self._block_size *= self.pcp_size
@@ -183,6 +184,7 @@ class KVPoolScheduler:
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
token_ids=request.prompt_token_ids[:num_tokens_to_compute].copy(),
)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = (
@@ -199,6 +201,7 @@ class KVPoolScheduler:
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
if req_meta is not None:
meta.add_request(req_meta)
@@ -227,6 +230,7 @@ class KVPoolScheduler:
token_len=num_tokens_to_compute,
allocated_block_ids=new_block_ids,
num_saved_tokens=0,
token_ids=request_real.prompt_token_ids[:num_tokens_to_compute].copy(),
)
self._request_trackers[req_id] = request_tracker
last_chunk_tokens_num = (
@@ -242,6 +246,7 @@ class KVPoolScheduler:
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
# decode/chunked request
@@ -276,6 +281,7 @@ class KVPoolScheduler:
block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
if req_meta is not None:
meta.add_request(req_meta)
@@ -299,7 +305,6 @@ class KVPoolScheduler:
)
self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,


@@ -12,6 +12,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.distributed.kv_events import BlockStored
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash
@@ -74,6 +75,7 @@ class KVPoolWorker:
"consumer_is_to_put", False
)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
self.original_block_size = vllm_config.cache_config.block_size
self.block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
@@ -146,6 +148,10 @@ class KVPoolWorker:
self.m_store = real_backend( # type: ignore[misc]
parallel_config
)
kv_event_config = vllm_config.kv_events_config
self.enable_kv_events = False
if kv_event_config and kv_event_config.enable_kv_cache_events:
self.enable_kv_events = True
self.kv_send_thread: KVTransferThread | None = None
self.kv_recv_thread: KVTransferThread | None = None
@@ -209,6 +215,7 @@ class KVPoolWorker:
self.put_step,
ready_event_sending,
self.num_layers,
self.enable_kv_events,
)
self.kv_send_thread.start()
ready_event = threading.Event()
@@ -235,6 +242,7 @@ class KVPoolWorker:
self.put_step,
self.kv_role,
ready_event_sending,
self.enable_kv_events,
)
self.kv_send_thread.start()
if self.load_async:
@@ -641,3 +649,10 @@ class KVPoolWorker:
return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
except ValueError:
return -1
def get_kv_events(self) -> list[BlockStored]:
if self.enable_kv_events and self.kv_send_thread is not None:
# Collect stored KV events from the sending thread
events = self.kv_send_thread.get_kv_events()
return events
return []