From 48e10de8c95d4971406b9a6c5bf81944f714eb89 Mon Sep 17 00:00:00 2001 From: lidenghui1110 <30521952+lidenghui1110@users.noreply.github.com> Date: Sat, 17 Jan 2026 11:50:13 +0800 Subject: [PATCH] [Bugfix] fix cpu offload hang with tp=1 (#5963) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? As issue #5948 reported, when using cpu_offload_connector with TP=1, the server hangs on startup; we found several bugs here to fix. 1. Some crash errors were encountered because of code changes introduced by vLLM version updates; some of them can be fixed as in #5948, and this PR fixes all of them. 2. For the hang problem described in #5948, the direct cause is that in cpu_offload_connector the RPC client uses the same client id in the scheduler and the worker when tensor_parallel_size is 1; this PR forces the client ids to be different, which fixes the hang. - Why didn't we find this hang problem before? Because we used --distributed-executor-backend mp or tensor_parallel_size > 1 in our tests. In those old test cases the scheduler and workers are different processes, so the client ids built from `worker-{os.getpid()}` are not the same. But when using tensor_parallel_size=1, vLLM uses uniproc as the distributed-executor-backend by default, so the scheduler and worker are in the same process; the client ids are then identical and the server hangs. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? 
- vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060 Signed-off-by: lidenghui --- .../kv_pool/cpu_offload/cpu_kv_cache_manager.py | 8 ++++---- .../kv_pool/cpu_offload/cpu_offload_connector.py | 3 ++- .../kv_transfer/kv_pool/cpu_offload/metadata.py | 4 +++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py index 5f838016..24307f5f 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py @@ -5,12 +5,11 @@ from typing import Optional from vllm.logger import logger from vllm.utils.hashing import sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - PrefixCachingMetrics) +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock) from vllm.v1.core.single_type_kv_cache_manager import \ get_manager_for_kv_cache_spec from vllm.v1.kv_cache_interface import KVCacheSpec -from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.metrics.stats import (PrefixCacheStats, CachingMetrics) from vllm.v1.request import Request @@ -20,7 +19,7 @@ class CPUCacheStats: self.enable_prefix_caching = enable_prefix_caching self.log_stats = log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.cpu_prefix_cache_metrics = PrefixCachingMetrics() + self.cpu_prefix_cache_metrics = CachingMetrics() self.time_sec = int(time.time()) def log(self): @@ -111,6 +110,7 @@ class CPUKVCacheManager: block_pool=self.block_pool, kv_cache_spec=self.single_type_manager.kv_cache_spec, use_eagle=self.use_eagle, + alignment_tokens=self.block_size, ) num_computed_tokens = len(computed_blocks[0]) * self.block_size 
self.req_to_computed_blocks[request_id] = computed_blocks[0] diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py index 60128eb6..2ea31953 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py @@ -75,7 +75,8 @@ class CPUOffloadingConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, - kv_cache_config: Optional[KVCacheConfig] = None): + kv_cache_config: Optional["KVCacheConfig"] = None): + self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set()) if not vllm_config.cache_config.enable_prefix_caching: self.connector_scheduler: Optional[ CPUOffloadingConnectorScheduler] = None diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py index 31ff509f..ab5bc08c 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py @@ -45,7 +45,9 @@ class MetadataServer: class ZMQRPCClient: - def __init__(self, identity=f"worker-{os.getpid()}"): + def __init__(self, identity=None): + if identity is None: + identity = f"worker-{os.getpid()}-{id(self)}" logger.info(f"metadata client for worker {identity} started") self.ctx = zmq.Context() # type: ignore self.socket = make_zmq_socket(