diff --git a/tests/ut/distributed/mooncake/test_kv_transfer.py b/tests/ut/distributed/mooncake/test_kv_transfer.py
new file mode 100644
index 00000000..d0b31335
--- /dev/null
+++ b/tests/ut/distributed/mooncake/test_kv_transfer.py
@@ -0,0 +1,129 @@
+import threading
+import unittest
+from types import SimpleNamespace
+
+import torch
+
+if not hasattr(torch, "npu"):
+    # Stub torch.npu so the modules under test can be imported on hosts without an NPU.
+    torch.npu = SimpleNamespace(Event=object)  # type: ignore[attr-defined]
+
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
+    LayerMultiBlockReqMeta,
+    ReqMeta,
+)
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
+    KVCacheStoreLayerSendingThread,
+    KVCacheStoreSendingThread,
+)
+
+
+class _FakeKey:
+    def __init__(self, value: str):
+        self._value = value
+
+    def to_string(self) -> str:
+        return self._value
+
+
+class _FakeStore:
+    def __init__(self, exists_result: list[int]):
+        self.exists_result = exists_result
+        self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []
+
+    def set_device(self):
+        return None
+
+    def exists(self, keys: list[str]) -> list[int]:
+        # Return one existence state (1 = stored, 0 = missing) per requested key.
+        return self.exists_result[: len(keys)]
+
+    def put(self, keys, addrs, sizes):
+        self.put_calls.append((list(keys), list(addrs), list(sizes)))
+
+
+class _FakeTokenDatabase:
+    def process_tokens(self, token_len, block_hashes):
+        for i, _ in enumerate(block_hashes):
+            yield i * 16, (i + 1) * 16, _FakeKey(f"k{i}")
+
+    def prepare_value(self, start, end, block_ids):
+        block_id = start // 16
+        return [1000 + block_id], [end - start], block_id
+
+    def prepare_value_layer(self, start, end, block_ids, layer_id):
+        block_id = start // 16
+        return [2000 + layer_id * 100 + block_id], [end - start]
+
+
+class TestKVTransferMissingKeyPut(unittest.TestCase):
+    def test_sending_thread_only_puts_missing_keys(self):
+        store = _FakeStore(exists_result=[1, 0, 1, 0])  # k0/k2 stored, k1/k3 missing
+        token_db = _FakeTokenDatabase()
+        thread = KVCacheStoreSendingThread(
+            m_store=store,
+            token_database=token_db,
+            block_size=16,
+            tp_rank=0,
+            dcp_size=1,
+            put_step=1,
+            kv_role="kv_producer",
+            ready_event=threading.Event(),
+            enable_kv_event=False,
+        )
+
+        req_meta = ReqMeta(
+            req_id="req-1",
+            token_len_chunk=64,
+            block_ids=[0, 1, 2, 3],
+            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
+            current_event=None,
+        )
+        thread.add_stored_request("req-1")
+        thread.request_queue.put(req_meta)  # pairs with task_done() in _handle_request
+        thread._handle_request(req_meta)
+
+        self.assertEqual(len(store.put_calls), 1)
+        put_keys, put_addrs, put_sizes = store.put_calls[0]
+        self.assertEqual(put_keys, ["k1", "k3"])
+        self.assertEqual(put_addrs, [[1001], [1003]])
+        self.assertEqual(put_sizes, [[16], [16]])
+
+    def test_layer_sending_thread_only_puts_missing_keys(self):
+        store = _FakeStore(exists_result=[1, 0, 1, 0])  # k0/k2 stored, k1/k3 missing
+        token_db = _FakeTokenDatabase()
+        thread = KVCacheStoreLayerSendingThread(
+            m_store=store,
+            token_database=token_db,
+            block_size=16,
+            tp_rank=0,
+            dcp_size=1,
+            put_step=1,
+            ready_event=threading.Event(),
+            num_layers=2,
+            enable_kv_event=False,
+        )
+
+        req_meta = LayerMultiBlockReqMeta(
+            req_id="req-2",
+            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
+            starts=[0, 16, 32, 48],
+            ends=[16, 32, 48, 64],
+            block_ids=[0, 1, 2, 3],
+            layer_id=1,
+            is_last_chunk=False,
+            current_event=None,
+        )
+        thread.request_queue.put(req_meta)  # pairs with task_done() in _handle_request
+        thread._handle_request(req_meta)
+
+        self.assertEqual(len(store.put_calls), 1)
+        put_keys, put_addrs, put_sizes = store.put_calls[0]
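+        # Only the keys whose exists() state was 0 should have been written.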
+        self.assertEqual(put_keys, ["k1", "k3"])
+        self.assertEqual(put_addrs, [[2101], [2103]])
+        self.assertEqual(put_sizes, [[16], [16]])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
index a7058cbb..2cde12d3 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
@@ -398,7 +398,7 @@ class AscendConnectorMetadata(KVConnectorMetadata):
 
 
 @dataclass
-class LasyerMultiBlockReqMeta:
+class LayerMultiBlockReqMeta:
     req_id: str
     keys: list[LayerPoolKey]
     starts: list[int]
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
index a6df8c39..a1cf7c53 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
@@ -14,7 +14,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend im
 # isort: off
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
     ChunkedTokenDatabase,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 # isort: on
@@ -48,7 +48,7 @@ class KVTransferThread(threading.Thread):
 
     def add_request(
         self,
-        request: ReqMeta | LasyerMultiBlockReqMeta,
+        request: ReqMeta | LayerMultiBlockReqMeta,
     ) -> torch.Tensor:
         self.request_queue.put(request)
 
@@ -88,22 +88,23 @@ class KVTransferThread(threading.Thread):
     def lookup(
         self,
         keys: list[str],
-    ) -> int:
+    ) -> list[bool]:
         """
-        Checks the existence of KV cache of the tokens from the cache engine.
-        :param tokens: the input tokens, with shape [seq_len]
-        :return: An int indicating how many prefix tokens are cached.
+        Check which of the given keys already exist in the cache engine.
+        :param keys: the string keys to look up in the backing store.
+        :return: A list of booleans where True means the key exists in the store.
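+        Example: for keys ["k0", "k1"] where only "k0" is already
+            stored, the result is [True, False].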
""" try: res = self.m_store.exists(keys) # type: ignore[assignment] + exists_list = [False] * len(keys) for index, value in enumerate(res): # type: ignore[arg-type] - if value != 1: - return index - # all tokens where found, return the maximal end + exists_list[index] = value == 1 + return exists_list except Exception as e: logger.error(f"Remote connection failed in contains: {e}") - return 0 - return len(keys) + return [False] * len(keys) def update_kv_event(self, event: list[BlockStored]): with self.kv_event_lock: @@ -159,39 +157,44 @@ class KVCacheStoreSendingThread(KVTransferThread): starts = [] ends = [] keys = [] + block_hashes = [] if req_id not in self.stored_requests: self.request_queue.task_done() return - for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes): + for index, (start, end, key) in enumerate(self.token_database.process_tokens(token_len, req_meta.block_hashes)): starts.append(start) ends.append(end) keys.append(key.to_string()) + block_hashes.append(req_meta.block_hashes[index]) if not self.dcp_size > 1: starts = starts[self.tp_rank % self.put_step :: self.put_step] ends = ends[self.tp_rank % self.put_step :: self.put_step] keys = keys[self.tp_rank % self.put_step :: self.put_step] + block_hashes = block_hashes[self.tp_rank % self.put_step :: self.put_step] if not keys: self.dec_stored_request(req_id) return - skip_block_num = self.lookup(keys) + exists_states = self.lookup(keys) + missing_indices = [index for index, exists in enumerate(exists_states) if not exists] - if skip_block_num == len(keys): + if not missing_indices: self.dec_stored_request(req_id) return - starts = starts[skip_block_num:] - ends = ends[skip_block_num:] - keys = keys[skip_block_num:] + starts = [starts[index] for index in missing_indices] + ends = [ends[index] for index in missing_indices] + keys = [keys[index] for index in missing_indices] + block_hashes = [block_hashes[index] for index in missing_indices] logger.debug( - "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s", + "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s", len(keys), token_len // self.block_size, - skip_block_num, + len(missing_indices), req_id, ) @@ -206,7 +209,7 @@ class KVCacheStoreSendingThread(KVTransferThread): sizes = [] stored_events: list[BlockStored] = [] prev_key = None - new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]] + new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes] for index, start in enumerate(starts): addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids) addrs.append(addr) @@ -307,7 +310,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread): self.request_queue.put(req_meta) def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta + self, req_meta: LayerMultiBlockReqMeta ): starts = req_meta.starts ends = req_meta.ends @@ -330,16 +333,17 @@ class KVCacheStoreLayerSendingThread(KVTransferThread): for key in keys: key_list.append(key.to_string()) - skip_block_num = self.lookup(key_list) + exists_states = self.lookup(key_list) + missing_indices = [index for index, exists in enumerate(exists_states) if not exists] - if skip_block_num == len(key_list): + if not missing_indices: if is_last_chunk and layer_id == self.final_layer_id: self.set_finished_request(req_meta.req_id) return - starts = starts[skip_block_num:] - ends = ends[skip_block_num:] - key_list = key_list[skip_block_num:] + 
+        starts = [starts[index] for index in missing_indices]
+        ends = [ends[index] for index in missing_indices]
+        key_list = [key_list[index] for index in missing_indices]
 
         addr_list = []
         size_list = []
@@ -359,10 +367,10 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.task_done()
 
         logger.info(
-            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
-            len(keys),
+            "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
+            len(key_list),
             total_block,
-            skip_block_num,
+            len(missing_indices),
             req_meta.req_id,
         )
 
@@ -384,12 +392,12 @@ class KVCacheStoreLayerRecvingThread(KVTransferThread):
         self.get_event = get_event
 
     def add_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ) -> torch.Tensor:
         self.request_queue.put(req_meta)
 
     def _handle_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ):
         addr_list = []
         size_list = []
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
index 4e6cb0a5..a9dd0c2b 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
@@ -20,7 +20,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
     AscendConnectorMetadata,
     ChunkedTokenDatabase,
     KeyMetadata,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
@@ -399,7 +399,7 @@ class KVPoolWorker:
         if not is_finish:
             logger.info("Layerwise get failed")
         self.get_event.clear()
-        req_meta = LasyerMultiBlockReqMeta(
+        req_meta = LayerMultiBlockReqMeta(
             request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
         )
         self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
@@ -455,7 +455,7 @@
         if keys:
             keys = [list(row) for row in zip(*keys)]  # [layer_num,block_num]
             for layer_id, keys_multi_chunk in enumerate(keys):
-                req_meta = LasyerMultiBlockReqMeta(
+                req_meta = LayerMultiBlockReqMeta(
                     request.req_id,
                     keys_multi_chunk,
                     starts,
@@ -602,8 +602,9 @@
             )
             multi_tp_keys.append(new_str)
 
+        pp_base_keys = multi_tp_keys.copy()
         for i in range(1, self.pp_size):
-            for item in keys:
+            for item in pp_base_keys:
                 new_str = item.replace(  # type: ignore[attr-defined]
                     "@pp_rank:0", f"@pp_rank:{i}", 1
                 )