[Feature][main]reconstruction kvpool connector to ascend connector (#4438)
### What this PR does / why we need it? 1.In short, we renamed the existing MooncakeStoreConnector to AscendStoreConnector and extracted the storage engine interaction logic into a new Backend class. Associated RFC:https://github.com/vllm-project/vllm-ascend/issues/4329 2.Fixed the issue where the number of input parameters for the connector was incorrect, introduced in vllm 0.11.2 ### Does this PR introduce _any_ user-facing change? change MooncakeStoreConnector to AscendStoreConnector ### How was this patch tested? - vLLM version: v0.11.2 --------- Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -31,11 +31,12 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tp_group)
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
|
||||
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
||||
from vllm_ascend.utils import prefill_context_parallel_enable
|
||||
|
||||
@@ -634,7 +635,10 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
class MooncakeConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||
|
||||
@@ -944,7 +948,7 @@ class MooncakeConnectorWorker:
|
||||
else:
|
||||
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
||||
logger.info("Initializing Mooncake work %s", engine_id)
|
||||
self.engine = get_global_te(hostname, device_name=None)
|
||||
self.engine = global_te.get_transfer_engine(hostname, device_name=None)
|
||||
self.te_rpc_port = self.engine.get_rpc_port()
|
||||
|
||||
# Background thread for sending or receiving KV caches.
|
||||
@@ -1054,6 +1058,8 @@ class MooncakeConnectorWorker:
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
ptrs = []
|
||||
lengths = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
@@ -1061,13 +1067,15 @@ class MooncakeConnectorWorker:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
elif self.use_sparse:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % 3]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
else:
|
||||
cache_list = [
|
||||
cache_or_caches
|
||||
@@ -1076,8 +1084,9 @@ class MooncakeConnectorWorker:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
global_te.register_buffer(ptrs, lengths)
|
||||
# After KV Caches registered, start the sending or receiving thread.
|
||||
metadata = MooncakeAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
@@ -1101,14 +1110,6 @@ class MooncakeConnectorWorker:
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def _register(self, ptr, length):
|
||||
logger.debug(
|
||||
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
|
||||
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
|
||||
ret_value = self.engine.register_memory(ptr, length)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake memory registration failed.")
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.kv_send_thread.
|
||||
|
||||
Reference in New Issue
Block a user