diff --git a/tests/ut/distributed/mooncake/test_config_data.py b/tests/ut/distributed/mooncake/test_config_data.py index bd8d0793..57f09f07 100644 --- a/tests/ut/distributed/mooncake/test_config_data.py +++ b/tests/ut/distributed/mooncake/test_config_data.py @@ -6,6 +6,9 @@ from unittest.mock import MagicMock fake_engine = types.ModuleType("mooncake.engine") fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] sys.modules["mooncake.engine"] = fake_engine +fake_store = types.ModuleType("mooncake.store") +fake_store.ReplicateConfig = MagicMock() # type: ignore[attr-defined] +sys.modules["mooncake.store"] = fake_store from vllm_ascend.distributed.kvpool.backend.mooncake_backend import ( # noqa: E402 _convert_to_bytes, _parse_global_segment_size) diff --git a/vllm_ascend/distributed/kvpool/ascend_store_connector.py b/vllm_ascend/distributed/kvpool/ascend_store_connector.py index 753806c3..f1137612 100644 --- a/vllm_ascend/distributed/kvpool/ascend_store_connector.py +++ b/vllm_ascend/distributed/kvpool/ascend_store_connector.py @@ -136,21 +136,9 @@ class AscendStoreConnector(KVConnectorBase_V1): finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None - meta = self._get_connector_metadata() done_sending, done_recving = self.connector_worker.get_finished( finished_req_ids) - sended_and_finished: set[str] = set() - for item in list(self.sended_but_unfinished_reqs): - if item not in meta.unfinished_request_ids: - sended_and_finished.add(item) - self.sended_but_unfinished_reqs.remove(item) - for item in done_sending: - if item in meta.unfinished_request_ids: - self.sended_but_unfinished_reqs.add(item) - else: - sended_and_finished.add(item) - - return sended_and_finished, done_recving + return done_sending, done_recving class LookupKeyServer: diff --git a/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py 
b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py index 699c8481..53426971 100644 --- a/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py +++ b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Union # Third Party +from mooncake.store import ReplicateConfig # type: ignore from vllm.config import ParallelConfig from vllm.logger import logger from vllm.utils.network_utils import get_ip @@ -56,7 +57,11 @@ class MooncakeBackend(Backend): def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): try: - res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes) + config = ReplicateConfig() + config.preferred_segment = self.local_seg + config.prefer_alloc_in_same_node = True + res = self.store.batch_put_from_multi_buffers( + keys, addrs, sizes, config) for value in res: if value < 0: logger.error(f"Failed to put key {keys},res:{res}") @@ -66,7 +71,8 @@ class MooncakeBackend(Backend): def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): try: - res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes) + res = self.store.batch_get_into_multi_buffers( + keys, addrs, sizes, True) for value in res: if value < 0: logger.error(f"Failed to get key {keys}, res:{res}") diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py index 8800b5f5..4160f7ff 100644 --- a/vllm_ascend/distributed/kvpool/config_data.py +++ b/vllm_ascend/distributed/kvpool/config_data.py @@ -223,6 +223,8 @@ class LoadSpec: # Whether the scheduler allow us to load the tokens can_load: bool + token_len: int = 0 + @dataclass class RequestTracker: diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py index bfb6eba0..4b2061d8 100644 --- a/vllm_ascend/distributed/kvpool/kv_transfer.py +++ b/vllm_ascend/distributed/kvpool/kv_transfer.py @@ -125,7 +125,6 @@ 
class KVCacheStoreSendingThread(KVTransferThread): token_len = req_meta.token_len_chunk block_ids = req_meta.block_ids req_id = req_meta.req_id - is_last_chunk = req_meta.is_last_chunk current_event = req_meta.current_event starts = [] ends = [] @@ -142,15 +141,15 @@ class KVCacheStoreSendingThread(KVTransferThread): keys = keys[self.tp_rank % self.put_step::self.put_step] if not keys: - if is_last_chunk: - self.set_finished_request(req_id) + with self.done_task_lock: + self.stored_requests[req_id] -= 1 return skip_block_num = self.lookup(keys) if skip_block_num == len(keys): - if is_last_chunk: - self.set_finished_request(req_id) + with self.done_task_lock: + self.stored_requests[req_id] -= 1 return starts = starts[skip_block_num:] @@ -208,6 +207,7 @@ class KVCacheStoreRecvingThread(KVTransferThread): name="KVCacheStoreRecvingThread") def _handle_request(self, req_meta: ReqMeta): + token_len = req_meta.load_spec.token_len # type: ignore[union-attr] req_id = req_meta.req_id mask_num = ( req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr] @@ -216,7 +216,7 @@ class KVCacheStoreRecvingThread(KVTransferThread): size_list = [] key_list = [] for start, end, key in self.token_database.process_tokens( - req_meta.token_len_chunk, req_meta.block_hashes, mask_num): + token_len, req_meta.block_hashes, mask_num): addr, size, _ = self.token_database.prepare_value( start, end, req_meta.block_ids) key_list.append(key.to_string()) diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py index 863ee2bc..c5907894 100644 --- a/vllm_ascend/distributed/kvpool/pool_worker.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -134,6 +134,12 @@ class KVPoolWorker: self.use_mla, partitions) real_backend = backend_map.get(self.backend.lower()) + + # TODO(review): temporary mooncake-specific override — remove once no longer needed + if self.backend == "mooncake": + self.head_or_tp_rank = self.tp_rank + self.put_step = 1 + self.m_store = real_backend( # type: ignore[misc] parallel_config) @@ 
-245,7 +251,7 @@ class KVPoolWorker: token_len = request.load_spec.kvpool_cached_tokens + 1 else: token_len = request.load_spec.kvpool_cached_tokens - request.token_len_chunk = token_len + request.load_spec.token_len = token_len if self.use_layerwise: layerwise_retriever = self.retrieve_layer(request) next(layerwise_retriever) # first layer load