Files
xc-llm-ascend/vllm_ascend/distributed/kvpool/backend/memcache_backend.py
fems14 b662d914a4 [bugfix] [main] Fix KV cache query inconsistency across different TP ranks in the KV Pool (#5030)
### What this PR does / why we need it?
In the current KV Pool scenario for models that use MLA or GQA attention, where
different TP ranks generate identical KV caches, the system is designed
to store only a single copy. The previous approach allowed each card to
query storage requirements dynamically, but inconsistent query results
across cards led to incorrect storage. To fix this, the new solution
pre-allocates storage responsibilities; each card now simply stores its
pre-assigned blocks, bypassing the inconsistent query step and ensuring
data correctness.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: fems14 <1804143737@qq.com>
2025-12-15 21:56:05 +08:00

72 lines
2.4 KiB
Python

# Standard
from enum import Enum
import torch
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm_ascend.distributed.kvpool.backend.backend import Backend
class MmcDirect(Enum):
    """Copy-direction codes handed to the memcache store's batch calls.

    The integer ``.value`` of a member is what gets passed to the store.
    NOTE(review): the letters presumably stand for L = local, G = global
    pool and H = host memory -- confirm against the memcache_hybrid docs.
    """

    # Used by MemcacheBackend.put (batch_put_from_layers).
    COPY_L2G = 0
    # Used by MemcacheBackend.get (batch_get_into_layers).
    COPY_G2L = 1
    # Not referenced in this file.
    COPY_G2H = 2
    # Not referenced in this file.
    COPY_H2G = 3
class MemcacheBackend(Backend):
    """KV pool backend that delegates to memcache_hybrid's object store.

    Every operation is forwarded to a ``DistributedObjectStore`` instance
    created in ``__init__``. ``get`` and ``put`` are best-effort: failures
    are logged and never raised to the caller.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Create the distributed object store and initialise it for this rank.

        Args:
            parallel_config: Provides ``rank``, which is passed to the store's
                ``init`` call and reused by ``set_device``.

        Raises:
            ImportError: If ``memcache_hybrid`` cannot be imported.
            RuntimeError: If the store's ``init`` returns a non-zero status.
        """
        # Imported lazily so the module can be loaded without memcache_hybrid.
        try:
            from memcache_hybrid import DistributedObjectStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install memcache by following the instructions at "
                "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
                "to run vLLM with MemcacheConnector.") from e
        try:
            self.rank = parallel_config.rank
            self.store = DistributedObjectStore()
            res = self.store.init(self.rank)
            # Explicit check instead of ``assert``: asserts are removed under
            # ``python -O``, which would let a failed init go unnoticed.
            if res != 0:
                raise RuntimeError(
                    f"DistributedObjectStore.init({self.rank}) failed "
                    f"with status {res}")
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise

    def set_device(self):
        """Bind the current process to the NPU whose index equals ``self.rank``."""
        # NOTE(review): this uses the rank from ParallelConfig as the device
        # index; confirm that is valid when ranks span more than one node.
        device = torch.device(f"npu:{self.rank}")
        torch.npu.set_device(device)

    def register_buffer(self, ptrs: list[int], sizes: list[int]):
        """No-op: this backend does nothing with the supplied buffers."""
        pass

    def exists(self, keys: list[str]) -> list[int]:
        """Return the store's ``batch_is_exist`` result for ``keys``."""
        return self.store.batch_is_exist(keys)

    def get(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Fetch a batch of keys from the store via ``batch_get_into_layers``.

        Best-effort: a non-zero per-key status or an exception from the store
        is logged and not raised. Returns ``None``.

        Args:
            key: Keys to fetch.
            addr: Address lists forwarded unchanged to the store
                (presumably one list per key -- confirm against caller).
            size: Size lists forwarded unchanged to the store
                (presumably parallel to ``addr`` -- confirm against caller).
        """
        try:
            res = self.store.batch_get_into_layers(key, addr, size,
                                                   MmcDirect.COPY_G2L.value)
            # Log once per batch rather than once per failed entry; the
            # message already carries the full key list and status list.
            if any(value != 0 for value in res):
                logger.error("Failed to get key %s,res:%s", key, res)
        except Exception as e:
            logger.error("Failed to get key %s. %s", key, e)

    def put(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Store a batch of keys via ``batch_put_from_layers``.

        Best-effort: a non-zero per-key status or an exception from the store
        is logged and not raised. Returns ``None``.

        Args:
            key: Keys to store.
            addr: Address lists forwarded unchanged to the store
                (presumably one list per key -- confirm against caller).
            size: Size lists forwarded unchanged to the store
                (presumably parallel to ``addr`` -- confirm against caller).
        """
        try:
            res = self.store.batch_put_from_layers(key, addr, size,
                                                   MmcDirect.COPY_L2G.value)
            # Log once per batch; the message reports a *put* failure (the
            # previous text said "get", which was misleading).
            if any(value != 0 for value in res):
                logger.error("Failed to put key %s,res:%s", key, res)
        except Exception as e:
            logger.error("Failed to put key %s,error:%s", key, e)