Mooncake store use adxl interface (#3350)
Use adxl interface in mooncake store, mooncake PR https://github.com/kvcache-ai/Mooncake/pull/929 - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: LCAIZJ <leichao139636@163.com>
This commit is contained in:
@@ -424,6 +424,7 @@ class MooncakeStoreConfig:
|
|||||||
protocol: str
|
protocol: str
|
||||||
device_name: str
|
device_name: str
|
||||||
master_server_address: str
|
master_server_address: str
|
||||||
|
use_ascend_direct: bool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
||||||
@@ -436,7 +437,8 @@ class MooncakeStoreConfig:
|
|||||||
local_buffer_size=config.get("local_buffer_size", 1073741824),
|
local_buffer_size=config.get("local_buffer_size", 1073741824),
|
||||||
protocol=config.get("protocol", "tcp"),
|
protocol=config.get("protocol", "tcp"),
|
||||||
device_name=config.get("device_name", ""),
|
device_name=config.get("device_name", ""),
|
||||||
master_server_address=config.get("master_server_address"))
|
master_server_address=config.get("master_server_address"),
|
||||||
|
use_ascend_direct=config.get("use_ascend_direct", False))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_from_env() -> "MooncakeStoreConfig":
|
def load_from_env() -> "MooncakeStoreConfig":
|
||||||
|
|||||||
@@ -142,11 +142,27 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
block_ids = req_meta["block_ids"]
|
block_ids = req_meta["block_ids"]
|
||||||
req_id = req_meta["req_id"]
|
req_id = req_meta["req_id"]
|
||||||
is_last_chunk = req_meta["is_last_chunk"]
|
is_last_chunk = req_meta["is_last_chunk"]
|
||||||
torch.npu.current_stream().synchronize()
|
if self.m_store.config.use_ascend_direct:
|
||||||
for start, end, key in self.token_database.process_tokens(
|
addr_list = []
|
||||||
tokens, mask):
|
size_list = []
|
||||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
key_list = []
|
||||||
self.m_store.put(key, addr, size)
|
blockIds = []
|
||||||
|
for start, end, key in self.token_database.process_tokens(
|
||||||
|
tokens, mask):
|
||||||
|
addr, size, block_id = self.prepare_value(
|
||||||
|
start, end, block_ids)
|
||||||
|
key_list.append(key.to_string())
|
||||||
|
addr_list.append(addr)
|
||||||
|
size_list.append(size)
|
||||||
|
blockIds.append(block_id)
|
||||||
|
torch.npu.current_stream().synchronize()
|
||||||
|
self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
|
||||||
|
else:
|
||||||
|
torch.npu.current_stream().synchronize()
|
||||||
|
for start, end, key in self.token_database.process_tokens(
|
||||||
|
tokens, mask):
|
||||||
|
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||||
|
self.m_store.put(key, addr, size)
|
||||||
if is_last_chunk:
|
if is_last_chunk:
|
||||||
self.set_finished_request(req_id)
|
self.set_finished_request(req_id)
|
||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
@@ -173,10 +189,25 @@ class KVCacheStoreRecvingThread(KVTransferThread):
|
|||||||
mask = req_meta["mask"]
|
mask = req_meta["mask"]
|
||||||
block_ids = req_meta["block_ids"]
|
block_ids = req_meta["block_ids"]
|
||||||
req_id = req_meta["req_id"]
|
req_id = req_meta["req_id"]
|
||||||
for start, end, key in self.token_database.process_tokens(
|
if self.m_store.config.use_ascend_direct:
|
||||||
tokens, mask):
|
addr_list = []
|
||||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
size_list = []
|
||||||
self.m_store.get(key, addr, size)
|
key_list = []
|
||||||
|
blockIds = []
|
||||||
|
for start, end, key in self.token_database.process_tokens(
|
||||||
|
tokens, mask):
|
||||||
|
addr, size, block_id = self.prepare_value(
|
||||||
|
start, end, block_ids)
|
||||||
|
key_list.append(key.to_string())
|
||||||
|
addr_list.append(addr)
|
||||||
|
size_list.append(size)
|
||||||
|
blockIds.append(block_id)
|
||||||
|
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
|
||||||
|
else:
|
||||||
|
for start, end, key in self.token_database.process_tokens(
|
||||||
|
tokens, mask):
|
||||||
|
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||||
|
self.m_store.get(key, addr, size)
|
||||||
self.set_finished_request(req_id)
|
self.set_finished_request(req_id)
|
||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,13 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
# Third Party
|
# Third Party
|
||||||
|
from mooncake.store import ReplicateConfig # type: ignore
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||||
from vllm.utils import logger
|
from vllm.utils import get_ip, logger
|
||||||
|
|
||||||
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
|
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
|
||||||
|
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
|
||||||
|
|
||||||
from .config_data import MooncakeStoreConfig
|
from .config_data import MooncakeStoreConfig
|
||||||
|
|
||||||
@@ -36,17 +38,28 @@ class Mooncakestore():
|
|||||||
assert len(device_ids_list) > tp_rank
|
assert len(device_ids_list) > tp_rank
|
||||||
device_id = device_ids_list[tp_rank]
|
device_id = device_ids_list[tp_rank]
|
||||||
self.config = MooncakeStoreConfig.load_from_env()
|
self.config = MooncakeStoreConfig.load_from_env()
|
||||||
if self.config.protocol == "ascend":
|
|
||||||
local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \
|
|
||||||
":npu_" + str(device_id)
|
|
||||||
else:
|
|
||||||
local_hostname = self.config.local_hostname
|
|
||||||
self.store = MooncakeDistributedStore()
|
self.store = MooncakeDistributedStore()
|
||||||
ret = self.store.setup(local_hostname, self.config.metadata_server,
|
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
|
||||||
self.config.global_segment_size,
|
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
|
||||||
self.config.local_buffer_size,
|
":npu_" + str(device_id)
|
||||||
self.config.protocol, self.config.device_name,
|
ret = self.store.setup(local_hostname, self.config.metadata_server,
|
||||||
self.config.master_server_address)
|
self.config.global_segment_size,
|
||||||
|
self.config.local_buffer_size,
|
||||||
|
self.config.protocol,
|
||||||
|
self.config.device_name,
|
||||||
|
self.config.master_server_address)
|
||||||
|
else:
|
||||||
|
local_hostname = get_ip()
|
||||||
|
transfer_engine = get_global_te(local_hostname, device_name=None)
|
||||||
|
self.local_seg = local_hostname + ":" + str(
|
||||||
|
transfer_engine.get_rpc_port())
|
||||||
|
ret = self.store.setup(self.local_seg, self.config.metadata_server,
|
||||||
|
self.config.global_segment_size,
|
||||||
|
self.config.local_buffer_size,
|
||||||
|
self.config.protocol,
|
||||||
|
self.config.device_name,
|
||||||
|
self.config.master_server_address,
|
||||||
|
transfer_engine.get_engine())
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
msg = "Initialize mooncake failed."
|
msg = "Initialize mooncake failed."
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
@@ -61,6 +74,31 @@ class Mooncakestore():
|
|||||||
def batch_exists(self, keys: list[str]) -> list[bool]:
|
def batch_exists(self, keys: list[str]) -> list[bool]:
|
||||||
return self.store.batch_is_exist(keys)
|
return self.store.batch_is_exist(keys)
|
||||||
|
|
||||||
|
def get_batch(self, keys: list[str], addrs: list[list[int]],
              sizes: list[list[int]], block_ids: list[int]):
    """Read several keys from the store into caller-supplied buffers.

    Args:
        keys: Store keys to read.
        addrs: Per-key lists of destination buffer addresses.
        sizes: Per-key lists of buffer sizes, parallel to ``addrs``.
        block_ids: Accepted as part of the signature; not read by this
            method.

    Returns:
        None. Failures are logged rather than raised.
    """
    try:
        # NOTE(review): the trailing positional ``True`` is a flag defined
        # by the Mooncake store API -- confirm its meaning there.
        res = self.store.batch_get_into_multi_buffers(
            keys, addrs, sizes, True)
        # A negative entry is treated as a failed key. The message already
        # carries every key and every result, so report the batch once
        # instead of repeating the same line for each failing entry.
        if any(value < 0 for value in res):
            logger.error(f"Failed to get key {keys},res:{res}")
    except Exception as e:
        # Best effort: a failed fetch is logged and swallowed here.
        logger.error(f"Failed to get key {keys}. {e}")
|
||||||
|
|
||||||
|
def put_batch(self, keys: list[str], addrs: list[list[int]],
              sizes: list[list[int]], block_ids: list[int]):
    """Write several keys to the store from caller-supplied buffers.

    Args:
        keys: Store keys to write.
        addrs: Per-key lists of source buffer addresses.
        sizes: Per-key lists of buffer sizes, parallel to ``addrs``.
        block_ids: Accepted as part of the signature; not read by this
            method.

    Returns:
        None. Failures are logged rather than raised.
    """
    try:
        # Ask the store to place the data on this instance's own segment.
        config = ReplicateConfig()
        config.preferred_segment = self.local_seg
        config.prefer_alloc_in_same_node = True
        res = self.store.batch_put_from_multi_buffers(
            keys, addrs, sizes, config)
        # A negative entry is treated as a failed key. The message already
        # carries every key and every result, so report the batch once
        # instead of repeating the same line for each failing entry.
        if any(value < 0 for value in res):
            logger.error(f"Failed to put key {keys},res:{res}")
    except Exception as e:
        # Best effort: a failed store is logged and swallowed here.
        logger.error(f"Failed to put key {keys},error:{e}")
|
||||||
|
|
||||||
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||||
expect_res = sum(size)
|
expect_res = sum(size)
|
||||||
key_str = key.to_string()
|
key_str = key.to_string()
|
||||||
|
|||||||
28
vllm_ascend/distributed/mooncake/transfer_engine.py
Normal file
28
vllm_ascend/distributed/mooncake/transfer_engine.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from mooncake.engine import TransferEngine # type: ignore
|
||||||
|
|
||||||
|
_global_te = None
|
||||||
|
_global_te_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_te(hostname: str, device_name: Optional[str]):
    """Return the module-wide ``TransferEngine``, creating it on first use.

    Args:
        hostname: Passed to ``TransferEngine.initialize`` when the engine is
            created.
        device_name: Device name for initialization; ``None`` is replaced by
            an empty string.

    Returns:
        The shared ``TransferEngine`` instance.

    Raises:
        RuntimeError: If ``TransferEngine`` is ``None`` or ``initialize``
            returns a non-zero value.

    NOTE(review): once the engine exists, the ``hostname`` and
    ``device_name`` of later calls are ignored -- callers passing different
    values all receive the engine built by the first call.
    """
    global _global_te

    # Fast path: the engine is already built, so the lock is not needed.
    if _global_te is not None:
        return _global_te

    with _global_te_lock:
        # Re-check while holding the lock: another thread may have finished
        # creating the engine while this one was waiting.
        if _global_te is None:
            if TransferEngine is None:
                raise RuntimeError("mooncake is not available")
            engine = TransferEngine()
            resolved_device = "" if device_name is None else device_name
            ret_value = engine.initialize(hostname, "P2PHANDSHAKE",
                                          "ascend", resolved_device)
            if ret_value != 0:
                raise RuntimeError(
                    f"TransferEngine initialization failed with ret_value: {ret_value}"
                )
            # Publish only after a successful initialize().
            _global_te = engine
    return _global_te
|
||||||
@@ -32,6 +32,7 @@ from vllm.v1.request import RequestStatus
|
|||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||||
|
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
@@ -879,11 +880,6 @@ class MooncakeConnectorWorker:
|
|||||||
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
|
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
|
||||||
f" or equal to the decode_tp_size: {self._decode_tp_size}")
|
f" or equal to the decode_tp_size: {self._decode_tp_size}")
|
||||||
|
|
||||||
if TransferEngine is None:
|
|
||||||
raise RuntimeError("mooncake is not available")
|
|
||||||
logger.info("Initializing Mooncake work %s", engine_id)
|
|
||||||
self.engine = TransferEngine()
|
|
||||||
|
|
||||||
# Metadata.
|
# Metadata.
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.ascend_config = get_ascend_config()
|
self.ascend_config = get_ascend_config()
|
||||||
@@ -933,7 +929,8 @@ class MooncakeConnectorWorker:
|
|||||||
hostname = self.side_channel_host
|
hostname = self.side_channel_host
|
||||||
else:
|
else:
|
||||||
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
||||||
self._initialize(hostname=hostname, device_name=None)
|
logger.info("Initializing Mooncake work %s", engine_id)
|
||||||
|
self.engine = get_global_te(hostname, device_name=None)
|
||||||
self.te_rpc_port = self.engine.get_rpc_port()
|
self.te_rpc_port = self.engine.get_rpc_port()
|
||||||
|
|
||||||
# Background thread for sending or receiving KV caches.
|
# Background thread for sending or receiving KV caches.
|
||||||
|
|||||||
Reference in New Issue
Block a user