Files
xc-llm-ascend/vllm_ascend/distributed/kvpool/backend/memcache_backend.py
DreamerLeader db7cf9b0ca [bugfix] A2 Environment Pooling for Memcache Compatibility (#5601)
### What this PR does / why we need it?
Running memcache in the A2 environment requires memory buffers to be
registered explicitly, so this PR adds that registration logic. In addition,
memcache and HCCS conflict over link establishment during initialization on
A2, so the link is now established in advance, before memcache is initialized.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
7157596103

---------

Signed-off-by: fangjianwei <f30058701@china.huawei.com>
Co-authored-by: fangjianwei <f30058701@china.huawei.com>
2026-01-13 09:07:38 +08:00

96 lines
3.5 KiB
Python

# Standard
from enum import Enum
import torch
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
class MmcDirect(Enum):
    """Copy-direction codes passed to the memcache store's batch layer calls.

    In this file ``COPY_G2L`` is handed to ``batch_get_into_layers`` and
    ``COPY_L2G`` to ``batch_put_from_layers``; the remaining two members are
    not referenced by the code visible here.
    """

    # NOTE(review): the names suggest L = local, G = global pool, H = host --
    # confirm against the memcache documentation before relying on this.
    COPY_L2G = 0  # direction used when putting layers into the store
    COPY_G2L = 1  # direction used when getting layers out of the store
    COPY_G2H = 2
    COPY_H2G = 3
class MemcacheBackend(Backend):
    """KV-pool backend that forwards every operation to a memcache store.

    ``self.store`` is a ``memcache_hybrid.DistributedObjectStore`` and
    ``self.rank`` is the rank read from the parallel config; the rank is used
    both to initialise the store and to select the NPU device.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Create and initialise the distributed object store for this rank.

        Args:
            parallel_config: vLLM parallel configuration; only ``rank`` is
                read from it.

        Raises:
            ImportError: if ``memcache_hybrid`` is not installed.
            RuntimeError: if ``DistributedObjectStore.init`` returns a
                non-zero code.
            Exception: any other error raised during setup is logged and
                re-raised unchanged.
        """
        try:
            from memcache_hybrid import DistributedObjectStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install memcache by following the instructions at "
                "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
                "to run vLLM with MemcacheConnector.") from e

        try:
            if get_ascend_device_type() in {AscendDeviceType.A2}:
                from vllm.distributed import get_world_group

                # Per the change description for this file, memcache and HCCS
                # conflict over link establishment during initialisation on
                # A2, so a throw-away all_gather is issued on the world
                # device group *before* the store is initialised.
                tmp_tensor = torch.zeros(1, device="npu")
                output_tensor_list = [
                    torch.empty_like(tmp_tensor)
                    for _ in range(torch.distributed.get_world_size())
                ]
                torch.distributed.all_gather(
                    output_tensor_list,
                    tmp_tensor,
                    group=get_world_group().device_group)

            self.rank = parallel_config.rank
            self.store = DistributedObjectStore()
            res = self.store.init(self.rank)
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O`` and a failed init would then go unnoticed.
            if res != 0:
                raise RuntimeError(
                    "DistributedObjectStore.init failed for rank "
                    f"{self.rank} with return code {res}")
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise

    def set_device(self):
        """Make ``npu:<rank>`` the current NPU device."""
        device = torch.device(f"npu:{self.rank}")
        torch.npu.set_device(device)

    def register_buffer(self, ptrs: list[int], sizes: list[int]):
        """Register each ``(ptr, size)`` pair with the store on A2 devices.

        On any other device type this method does nothing.

        Args:
            ptrs: buffer addresses, paired positionally with ``sizes``.
            sizes: buffer sizes, paired positionally with ``ptrs``.
        """
        if get_ascend_device_type() not in {AscendDeviceType.A2}:
            return
        # NOTE(review): the return value of ``store.register_buffer`` is
        # ignored, as in the original code -- confirm whether it can report
        # a failure that should be surfaced.
        for ptr, size in zip(ptrs, sizes):
            self.store.register_buffer(ptr, size)

    def exists(self, keys: list[str]) -> list[int]:
        """Return the store's ``batch_is_exist`` result for ``keys``."""
        return self.store.batch_is_exist(keys)

    def get(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Fetch ``key`` from the store into the given layer buffers.

        Best effort: failures are logged, never raised, and nothing is
        returned.

        Args:
            key: keys to fetch.
            addr: per-key lists of destination addresses.
            size: per-key lists of sizes matching ``addr``.
        """
        try:
            res = self.store.batch_get_into_layers(key, addr, size,
                                                   MmcDirect.COPY_G2L.value)
            # One log line per call; the message already carries the whole
            # batch, so repeating it for every failing element adds nothing.
            if any(value != 0 for value in res):
                logger.error("Failed to get key %s,res:%s", key, res)
        except Exception as e:
            logger.error("Failed to get key %s. %s", key, e)

    def put(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Store the given layer buffers in the store under ``key``.

        Best effort: failures are logged, never raised, and nothing is
        returned.

        Args:
            key: keys to store.
            addr: per-key lists of source addresses.
            size: per-key lists of sizes matching ``addr``.
        """
        try:
            res = self.store.batch_put_from_layers(key, addr, size,
                                                   MmcDirect.COPY_L2G.value)
            # The message says "put": the original reported put failures as
            # "Failed to get key", which misdirected debugging.
            if any(value != 0 for value in res):
                logger.error("Failed to put key %s,res:%s", key, res)
        except Exception as e:
            logger.error("Failed to put key %s,error:%s", key, e)