From 0b48ddbc8b43cfd032fecabf83717511651a00a2 Mon Sep 17 00:00:00 2001
From: pz1116 <47019764+Pz1116@users.noreply.github.com>
Date: Tue, 31 Mar 2026 20:21:23 +0800
Subject: [PATCH] [Bugfix][0.18.0][KV Pool]Fix KV transfer put logic (#7718)

### What this PR does / why we need it?
Previously, when doing a put for the KV Pool, we found the first non-existing key and put all blocks starting from that index. However, if the prefix-cache blocks came from another request and some of them had been evicted by LRU, we would put blocks that still exist in the pool, causing MooncakeStore to print unnecessary logs in the master service.

What this PR does:
- Look up all the keys and only put the ones that are missing.
- Fix lookup_scheduler in pool_worker so it handles GQA correctly.
- Fix a few existing typos.
- Add a UT, written by codex.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: Pz1116
Co-authored-by: DreamerLeader <2270923832@qq.com>
Co-authored-by: fems14 <1804143737@qq.com>
---
 .../distributed/mooncake/test_kv_transfer.py  | 127 ++++++++++++++++++
 .../kv_pool/ascend_store/config_data.py       |   2 +-
 .../kv_pool/ascend_store/kv_transfer.py       |  66 ++++-----
 .../kv_pool/ascend_store/pool_worker.py       |   9 +-
 4 files changed, 168 insertions(+), 36 deletions(-)
 create mode 100644 tests/ut/distributed/mooncake/test_kv_transfer.py

diff --git a/tests/ut/distributed/mooncake/test_kv_transfer.py b/tests/ut/distributed/mooncake/test_kv_transfer.py
new file mode 100644
index 00000000..d0b31335
--- /dev/null
+++ b/tests/ut/distributed/mooncake/test_kv_transfer.py
@@ -0,0 +1,127 @@
+import threading
+import unittest
+from types import SimpleNamespace
+
+import torch
+
+if not hasattr(torch, "npu"):
+    torch.npu = SimpleNamespace(Event=object)  # type: ignore[attr-defined]
+
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
+    LayerMultiBlockReqMeta,
+    ReqMeta,
+)
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
+    KVCacheStoreLayerSendingThread,
+    KVCacheStoreSendingThread,
+)
+
+
+class _FakeKey:
+    def __init__(self, value: str):
+        self._value = value
+
+    def to_string(self) -> str:
+        return self._value
+
+
+class _FakeStore:
+    def __init__(self, exists_result: list[int]):
+        self.exists_result = exists_result
+        self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []
+
+    def set_device(self):
+        return None
+
+    def exists(self, keys: list[str]) -> list[int]:
+        # Return the exact number of states for the requested keys.
+        return self.exists_result[: len(keys)]
+
+    def put(self, keys, addrs, sizes):
+        self.put_calls.append((list(keys), list(addrs), list(sizes)))
+
+
+class _FakeTokenDatabase:
+    def process_tokens(self, token_len, block_hashes):
+        for i, _ in enumerate(block_hashes):
+            yield i * 16, (i + 1) * 16, _FakeKey(f"k{i}")
+
+    def prepare_value(self, start, end, block_ids):
+        block_id = start // 16
+        return [1000 + block_id], [end - start], block_id
+
+    def prepare_value_layer(self, start, end, block_ids, layer_id):
+        block_id = start // 16
+        return [2000 + layer_id * 100 + block_id], [end - start]
+
+
+class TestKVTransferMissingKeyPut(unittest.TestCase):
+    def test_sending_thread_only_puts_missing_keys(self):
+        store = _FakeStore(exists_result=[1, 0, 1, 0])
+        token_db = _FakeTokenDatabase()
+        thread = KVCacheStoreSendingThread(
+            m_store=store,
+            token_database=token_db,
+            block_size=16,
+            tp_rank=0,
+            dcp_size=1,
+            put_step=1,
+            kv_role="kv_producer",
+            ready_event=threading.Event(),
+            enable_kv_event=False,
+        )
+
+        req_meta = ReqMeta(
+            req_id="req-1",
+            token_len_chunk=64,
+            block_ids=[0, 1, 2, 3],
+            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
+            current_event=None,
+        )
+        thread.add_stored_request("req-1")
+        thread.request_queue.put(req_meta)
+        thread._handle_request(req_meta)
+
+        self.assertEqual(len(store.put_calls), 1)
+        put_keys, put_addrs, put_sizes = store.put_calls[0]
+        self.assertEqual(put_keys, ["k1", "k3"])
+        self.assertEqual(put_addrs, [[1001], [1003]])
+        self.assertEqual(put_sizes, [[16], [16]])
+
+    def test_layer_sending_thread_only_puts_missing_keys(self):
+        store = _FakeStore(exists_result=[1, 0, 1, 0])
+        token_db = _FakeTokenDatabase()
+        thread = KVCacheStoreLayerSendingThread(
+            m_store=store,
+            token_database=token_db,
+            block_size=16,
+            tp_rank=0,
+            dcp_size=1,
+            put_step=1,
+            ready_event=threading.Event(),
+            num_layers=2,
+            enable_kv_event=False,
+        )
+
+        req_meta = LayerMultiBlockReqMeta(
+            req_id="req-2",
+            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
+            starts=[0, 16, 32, 48],
+            ends=[16, 32, 48, 64],
+            block_ids=[0, 1, 2, 3],
+            layer_id=1,
+            is_last_chunk=False,
+            current_event=None,
+        )
+        thread.request_queue.put(req_meta)
+        thread._handle_request(req_meta)
+
+        self.assertEqual(len(store.put_calls), 1)
+        put_keys, put_addrs, put_sizes = store.put_calls[0]
+        self.assertEqual(put_keys, ["k1", "k3"])
+        self.assertEqual(put_addrs, [[2101], [2103]])
+        self.assertEqual(put_sizes, [[16], [16]])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
index a7058cbb..2cde12d3 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py
@@ -398,7 +398,7 @@ class AscendConnectorMetadata(KVConnectorMetadata):
 
 
 @dataclass
-class LasyerMultiBlockReqMeta:
+class LayerMultiBlockReqMeta:
     req_id: str
     keys: list[LayerPoolKey]
     starts: list[int]
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
index a6df8c39..a1cf7c53 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py
@@ -14,7 +14,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend im
 # isort: off
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
     ChunkedTokenDatabase,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 # isort: on
@@ -48,7 +48,7 @@ class KVTransferThread(threading.Thread):
 
     def add_request(
         self,
-        request: ReqMeta | LasyerMultiBlockReqMeta,
+        request: ReqMeta | LayerMultiBlockReqMeta,
     ) -> torch.Tensor:
         self.request_queue.put(request)
@@ -88,22 +88,20 @@ class KVTransferThread(threading.Thread):
     def lookup(
         self,
         keys: list[str],
-    ) -> int:
+    ) -> list[bool]:
         """
-        Checks the existence of KV cache of the tokens from the cache engine.
-        :param tokens: the input tokens, with shape [seq_len]
-        :return: An int indicating how many prefix tokens are cached.
+        Check the existence of all keys from the cache engine.
+        :return: A bool list where True means the key exists in store.
         """
         try:
             res = self.m_store.exists(keys)  # type: ignore[assignment]
+            exists_list = [False] * len(keys)
             for index, value in enumerate(res):  # type: ignore[arg-type]
-                if value != 1:
-                    return index
-            # all tokens where found, return the maximal end
+                exists_list[index] = value == 1
+            return exists_list
         except Exception as e:
             logger.error(f"Remote connection failed in contains: {e}")
-            return 0
-        return len(keys)
+            return [False] * len(keys)
 
     def update_kv_event(self, event: list[BlockStored]):
         with self.kv_event_lock:
@@ -159,39 +157,44 @@ class KVCacheStoreSendingThread(KVTransferThread):
         starts = []
         ends = []
         keys = []
+        block_hashes = []
         if req_id not in self.stored_requests:
             self.request_queue.task_done()
             return
-        for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
+        for index, (start, end, key) in enumerate(self.token_database.process_tokens(token_len, req_meta.block_hashes)):
             starts.append(start)
             ends.append(end)
             keys.append(key.to_string())
+            block_hashes.append(req_meta.block_hashes[index])
         if not self.dcp_size > 1:
             starts = starts[self.tp_rank % self.put_step :: self.put_step]
             ends = ends[self.tp_rank % self.put_step :: self.put_step]
             keys = keys[self.tp_rank % self.put_step :: self.put_step]
+            block_hashes = block_hashes[self.tp_rank % self.put_step :: self.put_step]
 
         if not keys:
             self.dec_stored_request(req_id)
             return
 
-        skip_block_num = self.lookup(keys)
+        exists_states = self.lookup(keys)
+        missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
 
-        if skip_block_num == len(keys):
+        if not missing_indices:
            self.dec_stored_request(req_id)
            return
 
-        starts = starts[skip_block_num:]
-        ends = ends[skip_block_num:]
-        keys = keys[skip_block_num:]
+        starts = [starts[index] for index in missing_indices]
+        ends = [ends[index] for index in missing_indices]
+        keys = [keys[index] for index in missing_indices]
+        block_hashes = [block_hashes[index] for index in missing_indices]
 
         logger.debug(
-            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
+            "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
             len(keys),
             token_len // self.block_size,
-            skip_block_num,
+            len(missing_indices),
             req_id,
         )
@@ -206,7 +209,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
         sizes = []
         stored_events: list[BlockStored] = []
         prev_key = None
-        new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
+        new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
         for index, start in enumerate(starts):
             addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
             addrs.append(addr)
@@ -307,7 +310,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.put(req_meta)
 
     def _handle_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ):
         starts = req_meta.starts
         ends = req_meta.ends
@@ -330,16 +333,17 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         for key in keys:
             key_list.append(key.to_string())
 
-        skip_block_num = self.lookup(key_list)
+        exists_states = self.lookup(key_list)
+        missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
 
-        if skip_block_num == len(key_list):
+        if not missing_indices:
             if is_last_chunk and layer_id == self.final_layer_id:
                 self.set_finished_request(req_meta.req_id)
             return
 
-        starts = starts[skip_block_num:]
-        ends = ends[skip_block_num:]
-        key_list = key_list[skip_block_num:]
+        starts = [starts[index] for index in missing_indices]
+        ends = [ends[index] for index in missing_indices]
+        key_list = [key_list[index] for index in missing_indices]
 
         addr_list = []
         size_list = []
@@ -359,10 +363,10 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.task_done()
 
         logger.info(
-            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
-            len(keys),
+            "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
+            len(key_list),
             total_block,
-            skip_block_num,
+            len(missing_indices),
             req_meta.req_id,
         )
@@ -384,12 +388,12 @@ class KVCacheStoreLayerRecvingThread(KVTransferThread):
         self.get_event = get_event
 
     def add_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ) -> torch.Tensor:
         self.request_queue.put(req_meta)
 
     def _handle_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ):
         addr_list = []
         size_list = []
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
index 4e6cb0a5..a9dd0c2b 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py
@@ -20,7 +20,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
     AscendConnectorMetadata,
     ChunkedTokenDatabase,
     KeyMetadata,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
@@ -399,7 +399,7 @@ class KVPoolWorker:
             if not is_finish:
                 logger.info("Layerwise get failed")
                 self.get_event.clear()
-            req_meta = LasyerMultiBlockReqMeta(
+            req_meta = LayerMultiBlockReqMeta(
                 request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
             )
             self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
@@ -455,7 +455,7 @@ class KVPoolWorker:
         if keys:
             keys = [list(row) for row in zip(*keys)]  # [layer_num,block_num]
             for layer_id, keys_multi_chunk in enumerate(keys):
-                req_meta = LasyerMultiBlockReqMeta(
+                req_meta = LayerMultiBlockReqMeta(
                     request.req_id,
                     keys_multi_chunk,
                     starts,
@@ -602,8 +602,9 @@ class KVPoolWorker:
             )
             multi_tp_keys.append(new_str)
 
+        pp_base_keys = multi_tp_keys.copy()
         for i in range(1, self.pp_size):
-            for item in keys:
+            for item in pp_base_keys:
                 new_str = item.replace(  # type: ignore[attr-defined]
                     "@pp_rank:0", f"@pp_rank:{i}", 1
                 )
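
Note for reviewers: a minimal sketch of the put flow this patch switches to, assuming a MooncakeStore-like `exists`/`put` interface. `put_missing_blocks` and its parameters are illustrative placeholders, not the actual vllm-ascend API.

```python
def put_missing_blocks(store, keys, addrs, sizes):
    # exists() is assumed to return 1 for each key already present in the pool.
    exists_states = store.exists(keys)
    missing = [i for i, state in enumerate(exists_states) if state != 1]
    if not missing:
        return 0  # every block is already cached; nothing to transfer
    # Put only the missing blocks, so blocks that survived LRU eviction are
    # not re-sent and MooncakeStore does not log redundant puts.
    store.put(
        [keys[i] for i in missing],
        [addrs[i] for i in missing],
        [sizes[i] for i in missing],
    )
    return len(missing)
```

With the previous first-miss logic, a single evicted prefix block caused every block after it to be re-put, including blocks still present in the pool.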