diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py index 4e69a112..033ea34d 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py @@ -1,6 +1,7 @@ +import importlib import math import threading -from collections.abc import Callable, Generator +from collections.abc import Generator import torch from vllm.config import VllmConfig @@ -14,9 +15,6 @@ from vllm.distributed import ( from vllm.logger import logger from vllm.v1.core.kv_cache_utils import BlockHash -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import MemcacheBackend -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import MooncakeBackend from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( AscendConnectorMetadata, ChunkedTokenDatabase, @@ -32,9 +30,15 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import KVTransferThread, ) -backend_map: dict[str, Callable[..., Backend]] = { - "mooncake": MooncakeBackend, - "memcache": MemcacheBackend, +backend_map = { + "mooncake": { + "name": "MooncakeBackend", + "path": "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend", + }, + "memcache": { + "name": "MemcacheBackend", + "path": "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend", + }, } @@ -125,7 +129,13 @@ class KVPoolWorker: self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions) - real_backend = backend_map.get(self.backend.lower()) + backend = backend_map.get(self.backend.lower()) + assert backend is not None + backend_path = backend.get("path") + backend_name = backend.get("name") + assert backend_path is not None and backend_name is not None + backend_module = importlib.import_module(backend_path) + real_backend = getattr(backend_module, backend_name) # be removed later if self.backend == "mooncake":