From 531d0e6fff8331224ef76aeea0f7614fcdfd8abc Mon Sep 17 00:00:00 2001
From: DreamerLeader <88812830+DreamerLeader@users.noreply.github.com>
Date: Thu, 9 Apr 2026 21:55:56 +0800
Subject: [PATCH] =?UTF-8?q?[v0.18.0][BugFix][KV=20Pool]Fix=20the=20conflic?=
 =?UTF-8?q?t=20between=20pooling=20scenarios=20=E2=80=A6=20(#8101)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…and PCP across machines

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: DreamLeader <2270923832@qq.com>
---
 .../ascend_store/backend/memcache_backend.py   | 19 ++++++-------------
 .../ascend_store/backend/mooncake_backend.py   |  5 +++--
 2 files changed, 9 insertions(+), 15 deletions(-)

diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py
index fc5bc070..a04c59f8 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py
@@ -3,6 +3,7 @@ from enum import Enum

 import torch
 from vllm.config import ParallelConfig
+from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import logger

 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
@@ -29,21 +30,13 @@ class MemcacheBackend(Backend):
         try:
             soc_version = get_ascend_device_type()
             if soc_version in {AscendDeviceType.A2}:
-                import torch
-                from vllm.distributed import get_world_group
-
                 tmp_tensor = torch.zeros(1, device="npu")
                 output_tensor_list = [torch.empty_like(tmp_tensor) for _ in range(torch.distributed.get_world_size())]
                 torch.distributed.all_gather(output_tensor_list, tmp_tensor, group=get_world_group().device_group)
-                self.rank = parallel_config.rank
-                self.store = DistributedObjectStore()
-                res = self.store.init(self.rank)
-                assert res == 0
-            else:
-                self.rank = parallel_config.rank
-                self.store = DistributedObjectStore()
-                res = self.store.init(self.rank)
-                assert res == 0
+            self.local_rank = get_world_group().local_rank
+            self.store = DistributedObjectStore()
+            res = self.store.init(self.local_rank)
+            assert res == 0
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
             raise
@@ -52,7 +45,7 @@
             raise

     def set_device(self):
-        device = torch.device(f"npu:{self.rank}")
+        device = torch.device(f"npu:{self.local_rank}")
         torch.npu.set_device(device)

     def register_buffer(self, ptrs: list[int], sizes: list[int]):
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py
index 078590f6..1e584aa4 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py
@@ -8,6 +8,7 @@
 import torch

 # Third Party
 from vllm.config import ParallelConfig
+from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip
@@ -30,7 +31,6 @@ class MooncakeBackend(Backend):
             ) from e
         self.config = MooncakeStoreConfig.load_from_env()
         self.store = MooncakeDistributedStore()
-        self.rank = parallel_config.rank
         if self.config.protocol == "ascend":
             local_hostname = get_ip()
             # ASCEND_ENABLE_USE_FABRIC_MEM: Enable unified memory address direct transmission scheme
@@ -67,7 +67,8 @@ class MooncakeBackend(Backend):
             raise RuntimeError(msg)

     def set_device(self):
-        device = torch.device(f"npu:{self.rank}")
+        local_rank = get_world_group().local_rank
+        device = torch.device(f"npu:{local_rank}")
         torch.npu.set_device(device)

     def register_buffer(self, ptrs: list[int], lengths: list[int]):