【main】ADXL/HIXL supports FabricMem Mode (#6806)

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

---------

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2026-03-05 21:04:11 +08:00
committed by GitHub
parent 50441e4650
commit ae394767d4
6 changed files with 46 additions and 40 deletions

View File

@@ -7,16 +7,14 @@ from dataclasses import dataclass
import torch
# Third Party
from mooncake.store import ReplicateConfig # type: ignore
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_GLOBAL_SEGMENT_SIZE = 1073741824 # 1.0 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
@@ -35,18 +33,34 @@ class MooncakeBackend(Backend):
self.rank = parallel_config.rank
if self.config.protocol == "ascend":
local_hostname = get_ip()
transfer_engine = global_te.get_transfer_engine(local_hostname, device_name=None)
self.local_seg = local_hostname + ":" + str(transfer_engine.get_rpc_port())
ret = self.store.setup(
self.local_seg,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine(),
)
# ASCEND_ENABLE_USE_FABRIC_MEM: Enables the unified memory address direct transmission scheme;
# it can only be used on the 800 I/T A3 series.
# The required supporting hardware versions are as follows:
if os.getenv("ASCEND_ENABLE_USE_FABRIC_MEM", "0") != "1":
transfer_engine = global_te.get_transfer_engine(local_hostname, device_name=None)
self.local_seg = local_hostname + ":" + str(transfer_engine.get_rpc_port())
ret = self.store.setup(
self.local_seg,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine(),
)
else:
self.local_seg = local_hostname
ret = self.store.setup(
self.local_seg,
self.config.metadata_server,
self.config.global_segment_size,
0,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
)
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
@@ -57,21 +71,15 @@ class MooncakeBackend(Backend):
torch.npu.set_device(device)
def register_buffer(self, ptrs: list[int], lengths: list[int]):
global_te.register_buffer(ptrs, lengths)
if os.getenv("ASCEND_ENABLE_USE_FABRIC_MEM", "0") != "1":
global_te.register_buffer(ptrs, lengths)
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
try:
soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}:
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
else:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config)
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
@@ -80,11 +88,7 @@ class MooncakeBackend(Backend):
def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
try:
soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}:
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
else:
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True)
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys}, res:{res}")

View File

@@ -30,7 +30,6 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import
KVCacheStoreSendingThread,
KVTransferThread,
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
backend_map = {
"mooncake": {
@@ -98,12 +97,6 @@ class KVPoolWorker:
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
soc_version = get_ascend_device_type()
# be removed later
if self.backend == "mooncake" and soc_version in {AscendDeviceType.A3}:
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
self.metadata = KeyMetadata(
model_config.model.rstrip("/").split("/")[-1],
self.head_or_tp_rank,