diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py
index abb3c9e..745d911 100644
--- a/vllm_ascend/distributed/mooncake/config_data.py
+++ b/vllm_ascend/distributed/mooncake/config_data.py
@@ -424,6 +424,7 @@ class MooncakeStoreConfig:
     protocol: str
     device_name: str
     master_server_address: str
+    use_ascend_direct: bool
 
     @staticmethod
     def from_file(file_path: str) -> "MooncakeStoreConfig":
@@ -436,7 +437,8 @@ class MooncakeStoreConfig:
             local_buffer_size=config.get("local_buffer_size", 1073741824),
             protocol=config.get("protocol", "tcp"),
             device_name=config.get("device_name", ""),
-            master_server_address=config.get("master_server_address"))
+            master_server_address=config.get("master_server_address"),
+            use_ascend_direct=config.get("use_ascend_direct", False))
 
     @staticmethod
     def load_from_env() -> "MooncakeStoreConfig":
diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py
index dee5101..4472f67 100644
--- a/vllm_ascend/distributed/mooncake/kv_transfer.py
+++ b/vllm_ascend/distributed/mooncake/kv_transfer.py
@@ -142,11 +142,27 @@ class KVCacheStoreSendingThread(KVTransferThread):
         block_ids = req_meta["block_ids"]
         req_id = req_meta["req_id"]
         is_last_chunk = req_meta["is_last_chunk"]
-        torch.npu.current_stream().synchronize()
-        for start, end, key in self.token_database.process_tokens(
-                tokens, mask):
-            addr, size, _ = self.prepare_value(start, end, block_ids)
-            self.m_store.put(key, addr, size)
+        if self.m_store.config.use_ascend_direct:
+            addr_list = []
+            size_list = []
+            key_list = []
+            blockIds = []
+            for start, end, key in self.token_database.process_tokens(
+                    tokens, mask):
+                addr, size, block_id = self.prepare_value(
+                    start, end, block_ids)
+                key_list.append(key.to_string())
+                addr_list.append(addr)
+                size_list.append(size)
+                blockIds.append(block_id)
+            torch.npu.current_stream().synchronize()
+            self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
+        else:
+            torch.npu.current_stream().synchronize()
+            for start, end, key in self.token_database.process_tokens(
+                    tokens, mask):
+                addr, size, _ = self.prepare_value(start, end, block_ids)
+                self.m_store.put(key, addr, size)
         if is_last_chunk:
             self.set_finished_request(req_id)
         self.request_queue.task_done()
@@ -173,10 +189,25 @@ class KVCacheStoreRecvingThread(KVTransferThread):
         mask = req_meta["mask"]
         block_ids = req_meta["block_ids"]
         req_id = req_meta["req_id"]
-        for start, end, key in self.token_database.process_tokens(
-                tokens, mask):
-            addr, size, _ = self.prepare_value(start, end, block_ids)
-            self.m_store.get(key, addr, size)
+        if self.m_store.config.use_ascend_direct:
+            addr_list = []
+            size_list = []
+            key_list = []
+            blockIds = []
+            for start, end, key in self.token_database.process_tokens(
+                    tokens, mask):
+                addr, size, block_id = self.prepare_value(
+                    start, end, block_ids)
+                key_list.append(key.to_string())
+                addr_list.append(addr)
+                size_list.append(size)
+                blockIds.append(block_id)
+            self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
+        else:
+            for start, end, key in self.token_database.process_tokens(
+                    tokens, mask):
+                addr, size, _ = self.prepare_value(start, end, block_ids)
+                self.m_store.get(key, addr, size)
         self.set_finished_request(req_id)
         self.request_queue.task_done()
 
diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py
index 2383749..bf522f7 100644
--- a/vllm_ascend/distributed/mooncake/mooncake_store.py
+++ b/vllm_ascend/distributed/mooncake/mooncake_store.py
@@ -2,11 +2,13 @@ import os
 
 # Third Party
+from mooncake.store import ReplicateConfig  # type: ignore
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
-from vllm.utils import logger
+from vllm.utils import get_ip, logger
 
 from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
+from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
 
 from .config_data import MooncakeStoreConfig
 
 
@@ -36,17 +38,28 @@ class Mooncakestore():
         assert len(device_ids_list) > tp_rank
         device_id = device_ids_list[tp_rank]
         self.config = MooncakeStoreConfig.load_from_env()
-        if self.config.protocol == "ascend":
-            local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \
-                ":npu_" + str(device_id)
-        else:
-            local_hostname = self.config.local_hostname
         self.store = MooncakeDistributedStore()
-        ret = self.store.setup(local_hostname, self.config.metadata_server,
-                               self.config.global_segment_size,
-                               self.config.local_buffer_size,
-                               self.config.protocol, self.config.device_name,
-                               self.config.master_server_address)
+        if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
+            local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
+                ":npu_" + str(device_id)
+            ret = self.store.setup(local_hostname, self.config.metadata_server,
+                                   self.config.global_segment_size,
+                                   self.config.local_buffer_size,
+                                   self.config.protocol,
+                                   self.config.device_name,
+                                   self.config.master_server_address)
+        else:
+            local_hostname = get_ip()
+            transfer_engine = get_global_te(local_hostname, device_name=None)
+            self.local_seg = local_hostname + ":" + str(
+                transfer_engine.get_rpc_port())
+            ret = self.store.setup(self.local_seg, self.config.metadata_server,
+                                   self.config.global_segment_size,
+                                   self.config.local_buffer_size,
+                                   self.config.protocol,
+                                   self.config.device_name,
+                                   self.config.master_server_address,
+                                   transfer_engine.get_engine())
         if ret != 0:
             msg = "Initialize mooncake failed."
             logger.error(msg)
@@ -61,6 +74,31 @@ class Mooncakestore():
     def batch_exists(self, keys: list[str]) -> list[bool]:
         return self.store.batch_is_exist(keys)
 
+    def get_batch(self, keys: list[str], addrs: list[list[int]],
+                  sizes: list[list[int]], block_ids: list[int]):
+        try:
+            res = self.store.batch_get_into_multi_buffers(
+                keys, addrs, sizes, True)
+            for value in res:
+                if value < 0:
+                    logger.error(f"Failed to get key {keys},res:{res}")
+        except Exception as e:
+            logger.error(f"Failed to get key {keys}. {e}")
+
+    def put_batch(self, keys: list[str], addrs: list[list[int]],
+                  sizes: list[list[int]], block_ids: list[int]):
+        try:
+            config = ReplicateConfig()
+            config.preferred_segment = self.local_seg
+            config.prefer_alloc_in_same_node = True
+            res = self.store.batch_put_from_multi_buffers(
+                keys, addrs, sizes, config)
+            for value in res:
+                if value < 0:
+                    logger.error(f"Failed to put key {keys},res:{res}")
+        except Exception as e:
+            logger.error(f"Failed to put key {keys},error:{e}")
+
     def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
         expect_res = sum(size)
         key_str = key.to_string()
@@ -85,4 +123,4 @@ class Mooncakestore():
 
     def close(self):
         self.store.close()
-        logger.info("Closed the mooncake store connection")
\ No newline at end of file
+        logger.info("Closed the mooncake store connection")
diff --git a/vllm_ascend/distributed/mooncake/transfer_engine.py b/vllm_ascend/distributed/mooncake/transfer_engine.py
new file mode 100644
index 0000000..e515da6
--- /dev/null
+++ b/vllm_ascend/distributed/mooncake/transfer_engine.py
@@ -0,0 +1,28 @@
+import threading
+from typing import Optional
+
+from mooncake.engine import TransferEngine  # type: ignore
+
+_global_te = None
+_global_te_lock = threading.Lock()
+
+
+def get_global_te(hostname: str, device_name: Optional[str]):
+    global _global_te
+    if _global_te is None:
+        with _global_te_lock:
+            # Double-Checked Locking
+            if _global_te is None:
+                if TransferEngine is None:
+                    raise RuntimeError("mooncake is not available")
+                transfer_engine = TransferEngine()
+                device_name = device_name if device_name is not None else ""
+                ret_value = transfer_engine.initialize(hostname,
+                                                       "P2PHANDSHAKE",
+                                                       "ascend", device_name)
+                if ret_value != 0:
+                    raise RuntimeError(
+                        f"TransferEngine initialization failed with ret_value: {ret_value}"
+                    )
+                _global_te = transfer_engine
+    return _global_te
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 3ed0ea5..23dfb32 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -32,6 +32,7 @@ from vllm.v1.request import RequestStatus
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
+from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -879,11 +880,6 @@ class MooncakeConnectorWorker:
             f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
             f" or equal to the decode_tp_size: {self._decode_tp_size}")
 
-        if TransferEngine is None:
-            raise RuntimeError("mooncake is not available")
-        logger.info("Initializing Mooncake work %s", engine_id)
-        self.engine = TransferEngine()
-
         # Metadata.
         self.vllm_config = vllm_config
         self.ascend_config = get_ascend_config()
@@ -933,7 +929,8 @@ class MooncakeConnectorWorker:
             hostname = self.side_channel_host
         else:
             hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
-        self._initialize(hostname=hostname, device_name=None)
+        logger.info("Initializing Mooncake work %s", engine_id)
+        self.engine = get_global_te(hostname, device_name=None)
         self.te_rpc_port = self.engine.get_rpc_port()
 
         # Background thread for sending or receiving KV caches.