### What this PR does / why we need it?
In the current KV Pool scenario for models like MLA and GQA, where
different TP ranks generate identical KV caches, the system is designed
to store only a single copy. The previous approach allowed each card to
query storage requirements dynamically, but inconsistent query results
across cards led to incorrect storage. To fix this, the new solution
pre-allocates storage responsibilities; each card now simply stores its
pre-assigned blocks, bypassing the inconsistent query step and ensuring
data correctness.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: fems14 <1804143737@qq.com>
72 lines · 2.4 KiB · Python
# Standard
|
|
from enum import Enum
|
|
|
|
import torch
|
|
from vllm.config import ParallelConfig
|
|
from vllm.logger import logger
|
|
|
|
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
|
|
|
|
|
class MmcDirect(Enum):
    """Transfer-direction codes passed to the DistributedObjectStore batch
    copy APIs.

    NOTE(review): naming suggests L = local (device), G = global pool,
    H = host memory — confirm against the memfabric_hybrid documentation.
    """

    COPY_L2G = 0  # local -> global (used by put)
    COPY_G2L = 1  # global -> local (used by get)
    COPY_G2H = 2  # global -> host
    COPY_H2G = 3  # host -> global
class MemcacheBackend(Backend):
    """KV-pool backend backed by memfabric_hybrid's DistributedObjectStore.

    Wraps the batched put/get/exists operations of the distributed object
    store, copying KV-cache blocks between device ("local") buffers and the
    global pool. One instance is created per rank.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Connect this rank to the distributed object store.

        Args:
            parallel_config: vLLM parallel configuration; only ``rank`` is
                read here.

        Raises:
            ImportError: if memcache_hybrid is not installed.
            RuntimeError: if the store fails to initialize.
        """
        # Imported lazily so this module can be imported without the
        # optional memcache_hybrid dependency installed.
        try:
            from memcache_hybrid import DistributedObjectStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install memcache by following the instructions at "
                "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
                "to run vLLM with MemcacheConnector.") from e
        try:
            self.rank = parallel_config.rank
            self.store = DistributedObjectStore()
            res = self.store.init(self.rank)
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would silently hide an init failure.
            if res != 0:
                raise RuntimeError(
                    f"DistributedObjectStore.init failed with code {res}")
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise

    def set_device(self):
        """Bind the current process to this rank's NPU device."""
        device = torch.device(f"npu:{self.rank}")
        torch.npu.set_device(device)

    def register_buffer(self, ptrs: list[int], sizes: list[int]):
        """No-op: this backend does not require buffer pre-registration."""
        pass

    def exists(self, keys: list[str]) -> list[int]:
        """Return per-key existence flags from the store (one int per key)."""
        return self.store.batch_is_exist(keys)

    def get(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Copy values for `key` from the global pool into local buffers.

        Args:
            key: object keys to fetch.
            addr: per-key lists of destination buffer addresses (one per layer).
            size: per-key lists of buffer sizes, matching ``addr``.

        Errors are logged, never raised: a failed get degrades to a cache
        miss for the caller.
        """
        try:
            res = self.store.batch_get_into_layers(key, addr, size,
                                                   MmcDirect.COPY_G2L.value)
            # Log once if anything failed; the original logged the full
            # result list repeatedly, once per failing entry.
            if any(code != 0 for code in res):
                logger.error("Failed to get key %s,res:%s", key, res)
        except Exception as e:
            logger.error("Failed to get key %s. %s", key, e)

    def put(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Copy values for `key` from local buffers into the global pool.

        Args:
            key: object keys to store.
            addr: per-key lists of source buffer addresses (one per layer).
            size: per-key lists of buffer sizes, matching ``addr``.

        Errors are logged, never raised: a failed put only loses caching.
        """
        try:
            res = self.store.batch_put_from_layers(key, addr, size,
                                                   MmcDirect.COPY_L2G.value)
            # Fixed copy-paste bug: this path previously logged
            # "Failed to get key" for a failed put.
            if any(code != 0 for code in res):
                logger.error("Failed to put key %s,res:%s", key, res)
        except Exception as e:
            logger.error("Failed to put key %s,error:%s", key, e)