Mooncake store use adxl inferface (#3350)

Use adxl inferface in mooncake store, mooncake PR
https://github.com/kvcache-ai/Mooncake/pull/929

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LCAIZJ <leichao139636@163.com>
This commit is contained in:
Chao Lei
2025-10-21 20:18:17 +08:00
committed by GitHub
parent ef3fabf399
commit 11f9bccf6b
5 changed files with 124 additions and 28 deletions

View File

@@ -424,6 +424,7 @@ class MooncakeStoreConfig:
protocol: str
device_name: str
master_server_address: str
use_ascend_direct: bool
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
@@ -436,7 +437,8 @@ class MooncakeStoreConfig:
local_buffer_size=config.get("local_buffer_size", 1073741824),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"))
master_server_address=config.get("master_server_address"),
use_ascend_direct=config.get("use_ascend_direct", False))
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":

View File

@@ -142,11 +142,27 @@ class KVCacheStoreSendingThread(KVTransferThread):
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
torch.npu.current_stream().synchronize()
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.put(key, addr, size)
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
torch.npu.current_stream().synchronize()
self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
else:
torch.npu.current_stream().synchronize()
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.put(key, addr, size)
if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
@@ -173,10 +189,25 @@ class KVCacheStoreRecvingThread(KVTransferThread):
mask = req_meta["mask"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.get(key, addr, size)
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.get(key, addr, size)
self.set_finished_request(req_id)
self.request_queue.task_done()

View File

@@ -2,11 +2,13 @@
import os
# Third Party
from mooncake.store import ReplicateConfig # type: ignore
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.utils import logger
from vllm.utils import get_ip, logger
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
from .config_data import MooncakeStoreConfig
@@ -36,17 +38,28 @@ class Mooncakestore():
assert len(device_ids_list) > tp_rank
device_id = device_ids_list[tp_rank]
self.config = MooncakeStoreConfig.load_from_env()
if self.config.protocol == "ascend":
local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \
":npu_" + str(device_id)
else:
local_hostname = self.config.local_hostname
self.store = MooncakeDistributedStore()
ret = self.store.setup(local_hostname, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol, self.config.device_name,
self.config.master_server_address)
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
":npu_" + str(device_id)
ret = self.store.setup(local_hostname, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address)
else:
local_hostname = get_ip()
transfer_engine = get_global_te(local_hostname, device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
@@ -61,6 +74,31 @@ class Mooncakestore():
def batch_exists(self, keys: list[str]) -> list[bool]:
return self.store.batch_is_exist(keys)
def get_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
res = self.store.batch_get_into_multi_buffers(
keys, addrs, sizes, True)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {keys}. {e}")
def put_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(
keys, addrs, sizes, config)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}")
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
expect_res = sum(size)
key_str = key.to_string()
@@ -85,4 +123,4 @@ class Mooncakestore():
def close(self):
self.store.close()
logger.info("Closed the mooncake store connection")
logger.info("Closed the mooncake store connection")

View File

@@ -0,0 +1,28 @@
import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore
_global_te = None
_global_te_lock = threading.Lock()
def get_global_te(hostname: str, device_name: Optional[str]):
global _global_te
if _global_te is None:
with _global_te_lock:
# Double-Checked Locking
if _global_te is None:
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else ""
ret_value = transfer_engine.initialize(hostname,
"P2PHANDSHAKE",
"ascend", device_name)
if ret_value != 0:
raise RuntimeError(
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
_global_te = transfer_engine
return _global_te

View File

@@ -32,6 +32,7 @@ from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -879,11 +880,6 @@ class MooncakeConnectorWorker:
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
f" or equal to the decode_tp_size: {self._decode_tp_size}")
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
logger.info("Initializing Mooncake work %s", engine_id)
self.engine = TransferEngine()
# Metadata.
self.vllm_config = vllm_config
self.ascend_config = get_ascend_config()
@@ -933,7 +929,8 @@ class MooncakeConnectorWorker:
hostname = self.side_channel_host
else:
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
self._initialize(hostname=hostname, device_name=None)
logger.info("Initializing Mooncake work %s", engine_id)
self.engine = get_global_te(hostname, device_name=None)
self.te_rpc_port = self.engine.get_rpc_port()
# Background thread for sending or receiving KV caches.