From 082aa2e5b7cba5b26da8b1bb9955c412ca1c2dd3 Mon Sep 17 00:00:00 2001 From: lty Date: Mon, 2 Feb 2026 16:26:18 +0800 Subject: [PATCH] [Bugfix] The service fails to start when the memcache pool is enabled (#6229) ### What this PR does / why we need it? The service fails to start when the memcache pool is enabled without configuring the mooncake path. ### Does this PR introduce _any_ user-facing change? NA ### How was this patch tested? ``` #memcache echo 200000 > /proc/sys/vm/nr_hugepages source /usr/local/memfabric_hybrid/set_env.sh source /usr/local/memcache_hybrid/set_env.sh source /usr/local/Ascend/ascend-toolkit/set_env.sh source /usr/local/Ascend/nnal/atb/set_env.sh export MMC_LOCAL_CONFIG_PATH=/usr/local/memcache_hybrid/latest/config/mmc-local.conf vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \ --host $local_ip \ --port 8002 \ --served-model-name model \ --data-parallel-size 2 \ --tensor-parallel-size 8 \ --enable-expert-parallel \ --no-enable-prefix-caching \ --no-enable-chunked-prefill \ --max-num-seqs 4 \ --max-model-len 8192 \ --max-num-batched-tokens 8192 \ --gpu-memory-utilization 0.9 \ --trust-remote-code \ --enforce-eager \ --quantization ascend \ --additional_config '{"ascend_scheduler_config":{"enabled":false}}' \ --kv-transfer-config \ '{ "kv_connector": "AscendStoreConnector", "kv_role": "kv_both", "kv_connector_extra_config": { "backend": "memcache", "lookup_rpc_port":"0" } }' ``` - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 --------- Signed-off-by: lty --- .../kv_pool/ascend_store/pool_worker.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py index 4e69a112..033ea34d 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py +++ 
b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py @@ -1,6 +1,7 @@ +import importlib import math import threading -from collections.abc import Callable, Generator +from collections.abc import Generator import torch from vllm.config import VllmConfig @@ -14,9 +15,6 @@ from vllm.distributed import ( from vllm.logger import logger from vllm.v1.core.kv_cache_utils import BlockHash -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import MemcacheBackend -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import MooncakeBackend from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( AscendConnectorMetadata, ChunkedTokenDatabase, @@ -32,9 +30,15 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import KVTransferThread, ) -backend_map: dict[str, Callable[..., Backend]] = { - "mooncake": MooncakeBackend, - "memcache": MemcacheBackend, +backend_map = { + "mooncake": { + "name": "MooncakeBackend", + "path": "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend", + }, + "memcache": { + "name": "MemcacheBackend", + "path": "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend", + }, } @@ -125,7 +129,13 @@ class KVPoolWorker: self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions) - real_backend = backend_map.get(self.backend.lower()) + backend = backend_map.get(self.backend.lower()) + assert backend is not None + backend_path = backend.get("path") + backend_name = backend.get("name") + assert backend_path is not None and backend_name is not None + backend_module = importlib.import_module(backend_path) + real_backend = getattr(backend_module, backend_name) # be removed later if self.backend == "mooncake":