[Misc] gen kv events in ascendconnector (#6593)
### What this PR does / why we need it?
refer to https://github.com/vllm-project/vllm-ascend/issues/6391,
Currently adapted the complete process of event publishing in vllm:
* `kv_connector_model_runner_mixin` invoke kv-connector
`get_kv_connector_kv_cache_events` func to collect kvevents
* in `scheduler.py`, its `update_from_output` func will invoke
`_update_from_kv_xfer_finished` which invoke
`connector.update_connector_output` to collect kv-events from all
kv-worker, and then scheduler will invoke `connector.take_events` api to
collect all kv-events and add it to the events which from
`kv_cache_manager`
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
You can add `--kv-events-config` parameter to the `vllm server` command
to enable this feature.
- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd
---------
Signed-off-by: yejj710 <abyss1999@163.com>
Co-authored-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -1,9 +1,15 @@
|
||||
import threading
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import (
|
||||
KVCacheEvent,
|
||||
KVConnectorKVEvents,
|
||||
KVEventAggregator,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import logger
|
||||
@@ -12,6 +18,7 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
|
||||
@@ -22,6 +29,40 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler imp
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
|
||||
|
||||
|
||||
class AscendStoreKVEvents(KVConnectorKVEvents):
    """KV cache events collected from Ascend-store KV workers.

    Thin wrapper around ``KVEventAggregator`` that accumulates events
    reported by one or more workers and can reduce them to the subset
    common to all workers.
    """

    def __init__(self, num_workers: int) -> None:
        # Aggregator tracks both the events and how many workers reported.
        self._aggregator = KVEventAggregator(num_workers)

    def add_events(self, events: "list[KVCacheEvent]") -> None:
        """Record a batch of events from a worker."""
        self._aggregator.add_events(events)

    def aggregate(self) -> "AscendStoreKVEvents":
        """
        Aggregate KV events and retain only common events.
        """
        shared = self._aggregator.get_common_events()
        # Replace the raw per-worker events with the common subset and
        # reset the worker count for the next collection interval.
        self._aggregator.clear_events()
        self._aggregator.add_events(shared)
        self._aggregator.reset_workers()
        return self

    def increment_workers(self, count: int = 1) -> None:
        """Increase the number of workers that contributed events."""
        self._aggregator.increment_workers(count)

    def get_all_events(self) -> "list[KVCacheEvent]":
        """Return every currently buffered event."""
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        """Return how many workers have contributed events so far."""
        return self._aggregator.get_number_of_workers()

    def clear_events(self) -> None:
        """Drop all buffered events and reset the worker count."""
        self._aggregator.clear_events()
        self._aggregator.reset_workers()

    def __repr__(self) -> str:
        return f"<AscendStoreKVEvents events={self.get_all_events()}>"
|
||||
|
||||
|
||||
class AscendStoreConnector(KVConnectorBase_V1):
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
|
||||
super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
|
||||
@@ -40,6 +81,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
||||
)
|
||||
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
self._kv_cache_events: AscendStoreKVEvents | None = None
|
||||
|
||||
self.sended_but_unfinished_reqs: set[str] = set()
|
||||
|
||||
@@ -82,6 +124,39 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
def update_connector_output(self, connector_output: "KVConnectorOutput"):
    """
    Update KVConnector state from worker-side connectors output.

    Args:
        connector_output (KVConnectorOutput): the worker-side connectors output.
    """
    incoming = connector_output.kv_cache_events
    # Nothing to merge: empty payload or events from a different connector.
    if not incoming:
        return
    if not isinstance(incoming, AscendStoreKVEvents):
        return

    if self._kv_cache_events is None:
        # First worker to report: adopt its event container wholesale.
        self._kv_cache_events = incoming
        return

    # Merge a subsequent worker's events into the existing container and
    # bump the contributing-worker count accordingly.
    self._kv_cache_events.add_events(incoming.get_all_events())
    self._kv_cache_events.increment_workers(incoming.get_number_of_workers())
|
||||
|
||||
def take_events(self) -> "Iterable[KVCacheEvent]":
    """
    Take the KV cache events from the connector.

    Yields:
        New KV cache events since the last call.
    """
    pending = self._kv_cache_events
    if pending is None:
        # No worker reported events this interval; yield nothing.
        return
    # Reduce to the events common to all reporting workers, then drain.
    pending.aggregate()
    yield from pending.get_all_events()
    # Cleanup runs once the caller has consumed every event.
    pending.clear_events()
    self._kv_cache_events = None
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
@@ -127,6 +202,18 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
||||
)
|
||||
return done_sending, done_recving
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> "AscendStoreKVEvents | None":
    """
    Get the KV connector kv cache events collected during the last interval.

    Returns None when the worker produced no events.
    """
    collected = self.connector_worker.get_kv_events()
    if not collected:
        return None

    # Wrap the raw worker events; this process counts as a single worker.
    wrapped = AscendStoreKVEvents(num_workers=1)
    wrapped.add_events(collected)
    return wrapped
|
||||
|
||||
|
||||
class LookupKeyServer:
|
||||
def __init__(
|
||||
|
||||
@@ -221,7 +221,6 @@ class RequestTracker:
|
||||
# Request id
|
||||
req_id: str
|
||||
|
||||
# The token ids that have been scheduled so far
|
||||
token_len: int
|
||||
|
||||
# The block ids that has been allocated so far
|
||||
@@ -233,6 +232,10 @@ class RequestTracker:
|
||||
# The number of tokens that have been saved
|
||||
num_saved_tokens: int = 0
|
||||
|
||||
# The token ids that have been scheduled so far
|
||||
# NOTE: This field will only be used when you enable kv-event
|
||||
token_ids: list[int] | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_new_request(
|
||||
new_request: "NewRequestData",
|
||||
@@ -256,6 +259,7 @@ class RequestTracker:
|
||||
|
||||
return RequestTracker(
|
||||
req_id=new_request.req_id,
|
||||
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(),
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
@@ -268,7 +272,6 @@ class RequestTracker:
|
||||
"""Update the request tracker when a running request is
|
||||
scheduled again
|
||||
"""
|
||||
|
||||
if len(new_block_ids) == 0:
|
||||
new_block_ids = []
|
||||
elif isinstance(new_block_ids, tuple):
|
||||
@@ -284,7 +287,7 @@ class RequestTracker:
|
||||
class ReqMeta:
|
||||
# Request id
|
||||
req_id: str
|
||||
# Request tokens
|
||||
# Number of tokens in this chunk
|
||||
token_len_chunk: int
|
||||
|
||||
block_ids: list[int]
|
||||
@@ -299,6 +302,11 @@ class ReqMeta:
|
||||
|
||||
current_event: torch.npu.Event | None = None
|
||||
|
||||
# The following parameters are only used for kv event generation
|
||||
# TODO: add lora_request which used for gen lora_id/lora_name in kv event
|
||||
token_ids: list[int] | None = None
|
||||
original_block_size: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_request_tracker(
|
||||
tracker: RequestTracker,
|
||||
@@ -308,15 +316,18 @@ class ReqMeta:
|
||||
block_hashes: list[BlockHash] | None = None,
|
||||
is_last_chunk: bool | None = None,
|
||||
discard_partial_chunks: bool = True,
|
||||
original_block_size: int | None = None,
|
||||
) -> Optional["ReqMeta"]:
|
||||
"""Create the request metadata from a request tracker.
|
||||
|
||||
Args:
|
||||
tracker (RequestTracker): the request tracker.
|
||||
block_size (int): the block size in vLLM.
|
||||
block_size (int): the block size in vLLM scheduler and AscendConnector.
|
||||
If context parallelism is enabled, block_size = block_size * pcp_size * dcp_size.
|
||||
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
|
||||
skip_save (bool): whether to skip the save operation.
|
||||
discard_partial_chunks (bool): whether to discard partial chunks.
|
||||
original_block_size (int | None): the block size in vLLM worker. This is only used for kv event generation.
|
||||
|
||||
Returns:
|
||||
the request metadata if we need to perform load/save
|
||||
@@ -342,6 +353,11 @@ class ReqMeta:
|
||||
if not skip_save:
|
||||
tracker.num_saved_tokens = num_tokens_to_save
|
||||
|
||||
# Get the token ids for kv event generation in kv_transfer
|
||||
token_ids = None
|
||||
if tracker.token_ids:
|
||||
token_ids = tracker.token_ids
|
||||
|
||||
# For load operation: check whether the request is scheduled to load
|
||||
if load_spec is not None and load_spec.can_load:
|
||||
logger.debug(
|
||||
@@ -361,6 +377,8 @@ class ReqMeta:
|
||||
load_spec=load_spec,
|
||||
block_hashes=block_hashes,
|
||||
is_last_chunk=is_last_chunk,
|
||||
token_ids=token_ids,
|
||||
original_block_size=original_block_size,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_events import BlockStored
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.core.kv_cache_utils import maybe_convert_block_hash
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
|
||||
|
||||
@@ -41,6 +43,8 @@ class KVTransferThread(threading.Thread):
|
||||
# TODO(jianzs): make this configurable
|
||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||
self.finished_requests: set[str] = set()
|
||||
self.kv_event_lock = threading.Lock()
|
||||
self.kv_events: list[BlockStored] = []
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
@@ -101,6 +105,16 @@ class KVTransferThread(threading.Thread):
|
||||
return 0
|
||||
return len(keys)
|
||||
|
||||
def update_kv_event(self, event: "list[BlockStored]") -> None:
    """Append newly generated block-stored events to the shared buffer.

    Thread-safe: the buffer is also drained by ``get_kv_events`` from
    another thread, so all access goes through ``kv_event_lock``.
    """
    with self.kv_event_lock:
        self.kv_events.extend(event)
|
||||
|
||||
def get_kv_events(self) -> "list[BlockStored]":
    """Drain and return all buffered KV events.

    Thread-safe: snapshot the buffer and empty it under the lock, so
    concurrent ``update_kv_event`` calls are never lost.
    """
    with self.kv_event_lock:
        drained = list(self.kv_events)
        self.kv_events.clear()
        return drained
|
||||
|
||||
|
||||
class KVCacheStoreSendingThread(KVTransferThread):
|
||||
def __init__(
|
||||
@@ -113,6 +127,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
put_step: int,
|
||||
kv_role: str,
|
||||
ready_event: threading.Event,
|
||||
enable_kv_event: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
|
||||
@@ -120,6 +135,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
self.put_step = put_step
|
||||
self.kv_role = kv_role
|
||||
self.stored_requests = defaultdict[str, int](int)
|
||||
self.enable_kv_event = enable_kv_event
|
||||
|
||||
def add_stored_request(self, req_id: str):
|
||||
with self.done_task_lock:
|
||||
@@ -188,11 +204,30 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
"""
|
||||
addrs = []
|
||||
sizes = []
|
||||
stored_events: list[BlockStored] = []
|
||||
prev_key = None
|
||||
new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
|
||||
for index, start in enumerate(starts):
|
||||
addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
|
||||
addrs.append(addr)
|
||||
sizes.append(size)
|
||||
|
||||
# Create KV event
|
||||
if self.enable_kv_event:
|
||||
token_ids = req_meta.token_ids[start : ends[index]] if req_meta.token_ids is not None else None
|
||||
stored_event = BlockStored(
|
||||
block_hashes=[new_block_hashes[index]],
|
||||
parent_block_hash=prev_key,
|
||||
token_ids=token_ids,
|
||||
block_size=req_meta.original_block_size,
|
||||
lora_id=None,
|
||||
medium="cpu",
|
||||
lora_name=None,
|
||||
)
|
||||
stored_events.append(stored_event)
|
||||
prev_key = new_block_hashes[index]
|
||||
logger.debug(f"Added kv cache event '{stored_event}' to kv cache events queue")
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)
|
||||
|
||||
@@ -200,6 +235,10 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
current_event.synchronize()
|
||||
self.m_store.put(keys, addrs, sizes)
|
||||
|
||||
# TODO Query specific replica info to update the event
|
||||
if self.enable_kv_event and stored_events is not None:
|
||||
self.update_kv_event(stored_events)
|
||||
|
||||
self.dec_stored_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
@@ -253,12 +292,14 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
put_step: int,
|
||||
ready_event: threading.Event,
|
||||
num_layers: int,
|
||||
enable_kv_event: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
|
||||
)
|
||||
self.final_layer_id = num_layers - 1
|
||||
self.put_step = put_step
|
||||
self.enable_kv_event = enable_kv_event
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: ReqMeta
|
||||
|
||||
@@ -37,6 +37,7 @@ class KVPoolScheduler:
|
||||
self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
|
||||
self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
|
||||
|
||||
self.original_block_size = vllm_config.cache_config.block_size
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
if self.pcp_size > 1:
|
||||
self._block_size *= self.pcp_size
|
||||
@@ -183,6 +184,7 @@ class KVPoolScheduler:
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
token_ids=request.prompt_token_ids[:num_tokens_to_compute].copy(),
|
||||
)
|
||||
self._request_trackers[request.req_id] = request_tracker
|
||||
last_chunk_tokens_num = (
|
||||
@@ -199,6 +201,7 @@ class KVPoolScheduler:
|
||||
block_hashes=request_real.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
original_block_size=self.original_block_size,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
@@ -227,6 +230,7 @@ class KVPoolScheduler:
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=new_block_ids,
|
||||
num_saved_tokens=0,
|
||||
token_ids=request_real.prompt_token_ids[:num_tokens_to_compute].copy(),
|
||||
)
|
||||
self._request_trackers[req_id] = request_tracker
|
||||
last_chunk_tokens_num = (
|
||||
@@ -242,6 +246,7 @@ class KVPoolScheduler:
|
||||
block_hashes=request_real.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
original_block_size=self.original_block_size,
|
||||
)
|
||||
|
||||
# decode/chunked request
|
||||
@@ -276,6 +281,7 @@ class KVPoolScheduler:
|
||||
block_hashes=request.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
original_block_size=self.original_block_size,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
@@ -299,7 +305,6 @@ class KVPoolScheduler:
|
||||
)
|
||||
|
||||
self._request_trackers[request_id] = request_tracker
|
||||
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
|
||||
@@ -12,6 +12,7 @@ from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.distributed.kv_events import BlockStored
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
@@ -74,6 +75,7 @@ class KVPoolWorker:
|
||||
"consumer_is_to_put", False
|
||||
)
|
||||
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
|
||||
self.original_block_size = vllm_config.cache_config.block_size
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
if self.pcp_size > 1:
|
||||
@@ -146,6 +148,10 @@ class KVPoolWorker:
|
||||
self.m_store = real_backend( # type: ignore[misc]
|
||||
parallel_config
|
||||
)
|
||||
kv_event_config = vllm_config.kv_events_config
|
||||
self.enable_kv_events = False
|
||||
if kv_event_config and kv_event_config.enable_kv_cache_events:
|
||||
self.enable_kv_events = True
|
||||
|
||||
self.kv_send_thread: KVTransferThread | None = None
|
||||
self.kv_recv_thread: KVTransferThread | None = None
|
||||
@@ -209,6 +215,7 @@ class KVPoolWorker:
|
||||
self.put_step,
|
||||
ready_event_sending,
|
||||
self.num_layers,
|
||||
self.enable_kv_events,
|
||||
)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
@@ -235,6 +242,7 @@ class KVPoolWorker:
|
||||
self.put_step,
|
||||
self.kv_role,
|
||||
ready_event_sending,
|
||||
self.enable_kv_events,
|
||||
)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
@@ -641,3 +649,10 @@ class KVPoolWorker:
|
||||
return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
|
||||
except ValueError:
|
||||
return -1
|
||||
|
||||
def get_kv_events(self) -> "list[BlockStored]":
    """Collect stored KV events from the sending thread.

    Returns an empty list when KV events are disabled or no sending
    thread is running.
    """
    if self.enable_kv_events and self.kv_send_thread is not None:
        # Drain the events accumulated by the sending thread since the
        # last collection interval.
        return self.kv_send_thread.get_kv_events()
    return []
|
||||
|
||||
Reference in New Issue
Block a user