This PR lets the pooling KV connector support the PCP feature - vLLM version: v0.11.2 --------- Signed-off-by: fjw <2270923832@qq.com> Signed-off-by: SlightwindSec <slightwindsec@gmail.com> Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
193 lines
7.3 KiB
Python
import threading
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
import zmq
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
|
from vllm.forward_context import ForwardContext
|
|
from vllm.utils import logger
|
|
from vllm.utils.network_utils import make_zmq_socket
|
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.request import Request
|
|
from vllm.v1.serial_utils import MsgpackDecoder
|
|
|
|
from vllm_ascend.distributed.kvpool.pool_scheduler import (
|
|
KVPoolScheduler, get_zmq_rpc_path_lookup)
|
|
from vllm_ascend.distributed.kvpool.pool_worker import KVPoolWorker
|
|
|
|
|
|
class AscendStoreConnector(KVConnectorBase_V1):
    """KV connector that offloads KV-cache blocks to an external KV pool.

    Scheduler-role instances delegate to :class:`KVPoolScheduler`; worker-role
    instances delegate to :class:`KVPoolWorker`. On worker rank 0 a
    ``LookupKeyServer`` is additionally started so schedulers can query key
    availability over ZMQ.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 role: KVConnectorRole,
                 kv_cache_config: Optional[KVCacheConfig] = None):
        super().__init__(vllm_config=vllm_config,
                         role=role,
                         kv_cache_config=kv_cache_config)
        # "kv_producer" / "kv_consumer" / "kv_both" per kv_transfer_config.
        self.kv_role = vllm_config.kv_transfer_config.kv_role

        # When True, KV is saved/loaded layer by layer during forward instead
        # of in one shot after the step.
        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "use_layerwise", False)

        connector_name = vllm_config.kv_transfer_config.kv_connector
        if connector_name == "MooncakeConnectorStoreV1":
            logger.warning(
                "It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
            )

        self.kv_caches: dict[str, torch.Tensor] = {}

        # Request ids whose KV save completed while the request itself was
        # still marked unfinished; drained in get_finished().
        self.sended_but_unfinished_reqs: set[str] = set()

        # Exactly one of these is populated, depending on the role. Both are
        # initialized to None so the `assert ... is not None` guards in the
        # delegating methods raise AssertionError rather than AttributeError
        # when a method is called on the wrong role.
        self.connector_scheduler: Optional[KVPoolScheduler] = None
        self.connector_worker: Optional[KVPoolWorker] = None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = KVPoolScheduler(vllm_config,
                                                       self.use_layerwise)
        else:
            self.connector_worker = KVPoolWorker(
                vllm_config,
                self.use_layerwise,
            )
            assert self.connector_worker is not None
            # Only rank 0 hosts the lookup server for the worker group.
            if vllm_config.parallel_config.rank == 0:
                self.lookup_server = LookupKeyServer(self.connector_worker,
                                                     vllm_config,
                                                     self.use_layerwise)

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Return (num externally cached tokens, whether load is async)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Notify the scheduler-side state that blocks were allocated."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the per-step metadata consumed by the worker side."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Handle request completion; may defer block freeing to async save."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Hand the per-layer KV-cache tensors to the worker backend."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Kick off KV loading for the current step's metadata."""
        assert self.connector_worker is not None
        self.connector_worker.start_load_kv(self._get_connector_metadata())

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Block until the current layer's KV load completes (layerwise only)."""
        if not self.use_layerwise:
            return
        assert self.connector_worker is not None
        self.connector_worker.wait_for_layer_load()

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """Save one layer's KV to the pool (layerwise mode, producer only)."""
        if not self.use_layerwise:
            return

        if self.kv_role == "kv_consumer":
            # Don't do save if the role is kv_consumer
            return
        assert self.connector_worker is not None
        self.connector_worker.save_kv_layer(self._get_connector_metadata())

    def wait_for_save(self):
        """Wait for the one-shot (non-layerwise) save to finish."""
        if self.kv_role == "kv_consumer":
            # Don't do save if the role is kv_consumer
            return

        # In layerwise mode saving already happened per layer.
        if self.use_layerwise:
            return

        assert self.connector_worker is not None
        self.connector_worker.wait_for_save(self._get_connector_metadata())

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
        assert self.connector_worker is not None
        meta = self._get_connector_metadata()
        done_sending, done_recving = self.connector_worker.get_finished()
        sended_and_finished: set[str] = set()
        # Requests parked earlier whose request is now finished can be
        # reported; iterate over a snapshot since we mutate the set.
        for item in list(self.sended_but_unfinished_reqs):
            if item not in meta.unfinished_request_ids:
                sended_and_finished.add(item)
                self.sended_but_unfinished_reqs.remove(item)
        # Newly-saved requests: report only those whose request is finished,
        # park the rest for a later call.
        for item in done_sending:
            if item in meta.unfinished_request_ids:
                self.sended_but_unfinished_reqs.add(item)
            else:
                sended_and_finished.add(item)

        return sended_and_finished, done_recving
|
|
|
|
|
|
class LookupKeyServer:
    """ZMQ REP server answering KV-key lookup requests on worker rank 0.

    A daemon thread serves requests of the form
    ``[token_len (4-byte big-endian), *msgpack hash frames]`` and replies with
    the 4-byte big-endian count returned by ``pool_worker.lookup_scheduler``.
    """

    def __init__(
        self,
        pool_worker: KVPoolWorker,
        vllm_config: "VllmConfig",
        use_layerwise: bool,
    ):
        self.decoder = MsgpackDecoder()
        self.decoder_tensor = MsgpackDecoder(torch.Tensor)
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        socket_path = get_zmq_rpc_path_lookup(vllm_config)
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )

        self.pool_worker = pool_worker
        self.running = True
        self.use_layerwise = use_layerwise

        def process_request():
            while self.running:
                try:
                    all_frames = self.socket.recv_multipart(copy=False)
                except zmq.ZMQError:  # type: ignore[attr-defined]
                    # Socket closed / context terminated by close():
                    # exit the serving loop instead of crashing the thread.
                    break
                token_len = int.from_bytes(all_frames[0], byteorder="big")
                hash_frames = all_frames[1:]
                hashes_str = self.decoder.decode(hash_frames)
                result = self.pool_worker.lookup_scheduler(
                    token_len, hashes_str, self.use_layerwise)
                response = result.to_bytes(4, "big")
                self.socket.send(response)

        self.thread = threading.Thread(target=process_request, daemon=True)
        self.thread.start()

    def close(self):
        """Stop the serving thread and release all ZMQ resources."""
        # Signal the loop to stop, then close the socket; terminating the
        # context unblocks a thread stuck in recv_multipart() (it raises
        # ZMQError, which the loop catches and exits on).
        self.running = False
        self.socket.close(linger=0)
        self.ctx.term()
        self.thread.join(timeout=1.0)