[Bugfix][0.18.0][KV Pool]Fix KV transfer put logic (#7718)

### What this PR does / why we need it?
Previously, when doing a put for the KV Pool, we found the first non-existing key
and put all blocks starting from that index. However, if the prefix-cache blocks
come from another request and some of them have been evicted due to LRU, we end
up putting blocks that still exist in the pool, which causes MooncakeStore to
print unnecessary logs in the master service.

What this PR does:

- Look up all keys and put only the ones that are missing (see the sketch below).
- Fix `lookup_scheduler` in `pool_worker` so it handles GQA correctly.
- Fix a few existing typos (e.g. `LasyerMultiBlockReqMeta` → `LayerMultiBlockReqMeta`).
- Add a unit test, written by codex.
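
As a minimal standalone sketch of the change in put selection (names are illustrative only; the real logic lives in `KVCacheStoreSendingThread._handle_request`, and the store's `exists()` is assumed to return 1 for keys already present):

```python
keys = ["k0", "k1", "k2", "k3"]
exists = [1, 0, 1, 0]  # k1 and k3 were evicted by LRU

# Before: put everything from the first missing key onward,
# re-putting k2 even though it still exists in the pool.
first_missing = next(i for i, v in enumerate(exists) if v != 1)
put_before = keys[first_missing:]  # ['k1', 'k2', 'k3']

# After: put only the keys that are actually missing.
missing_indices = [i for i, v in enumerate(exists) if v != 1]
put_after = [keys[i] for i in missing_indices]  # ['k1', 'k3']
```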

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: Pz1116 <zpbzpb123123@gmail.com>
Co-authored-by: DreamerLeader <2270923832@qq.com>
Co-authored-by: fems14 <1804143737@qq.com>
Commit 0b48ddbc8b (parent 14411e911e), authored by pz1116 on 2026-03-31 20:21:23 +08:00 and committed by GitHub.
4 changed files with 168 additions and 36 deletions.

New file: unit test for the missing-key put logic (file name not shown in this extract).

@@ -0,0 +1,127 @@
import threading
import unittest
from types import SimpleNamespace

import torch

# Stub torch.npu so the module imports on hosts without Ascend NPU support.
if not hasattr(torch, "npu"):
    torch.npu = SimpleNamespace(Event=object)  # type: ignore[attr-defined]

from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
    LayerMultiBlockReqMeta,
    ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
    KVCacheStoreLayerSendingThread,
    KVCacheStoreSendingThread,
)


class _FakeKey:

    def __init__(self, value: str):
        self._value = value

    def to_string(self) -> str:
        return self._value


class _FakeStore:

    def __init__(self, exists_result: list[int]):
        self.exists_result = exists_result
        self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []

    def set_device(self):
        return None

    def exists(self, keys: list[str]) -> list[int]:
        # Return exactly one existence state per requested key.
        return self.exists_result[: len(keys)]

    def put(self, keys, addrs, sizes):
        self.put_calls.append((list(keys), list(addrs), list(sizes)))


class _FakeTokenDatabase:

    def process_tokens(self, token_len, block_hashes):
        for i, _ in enumerate(block_hashes):
            yield i * 16, (i + 1) * 16, _FakeKey(f"k{i}")

    def prepare_value(self, start, end, block_ids):
        block_id = start // 16
        return [1000 + block_id], [end - start], block_id

    def prepare_value_layer(self, start, end, block_ids, layer_id):
        block_id = start // 16
        return [2000 + layer_id * 100 + block_id], [end - start]


class TestKVTransferMissingKeyPut(unittest.TestCase):

    def test_sending_thread_only_puts_missing_keys(self):
        store = _FakeStore(exists_result=[1, 0, 1, 0])
        token_db = _FakeTokenDatabase()
        thread = KVCacheStoreSendingThread(
            m_store=store,
            token_database=token_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            kv_role="kv_producer",
            ready_event=threading.Event(),
            enable_kv_event=False,
        )
        req_meta = ReqMeta(
            req_id="req-1",
            token_len_chunk=64,
            block_ids=[0, 1, 2, 3],
            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
            current_event=None,
        )
        thread.add_stored_request("req-1")
        thread.request_queue.put(req_meta)
        thread._handle_request(req_meta)
        self.assertEqual(len(store.put_calls), 1)
        put_keys, put_addrs, put_sizes = store.put_calls[0]
        self.assertEqual(put_keys, ["k1", "k3"])
        self.assertEqual(put_addrs, [[1001], [1003]])
        self.assertEqual(put_sizes, [[16], [16]])

    def test_layer_sending_thread_only_puts_missing_keys(self):
        store = _FakeStore(exists_result=[1, 0, 1, 0])
        token_db = _FakeTokenDatabase()
        thread = KVCacheStoreLayerSendingThread(
            m_store=store,
            token_database=token_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            ready_event=threading.Event(),
            num_layers=2,
            enable_kv_event=False,
        )
        req_meta = LayerMultiBlockReqMeta(
            req_id="req-2",
            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
            starts=[0, 16, 32, 48],
            ends=[16, 32, 48, 64],
            block_ids=[0, 1, 2, 3],
            layer_id=1,
            is_last_chunk=False,
            current_event=None,
        )
        thread.request_queue.put(req_meta)
        thread._handle_request(req_meta)
        self.assertEqual(len(store.put_calls), 1)
        put_keys, put_addrs, put_sizes = store.put_calls[0]
        self.assertEqual(put_keys, ["k1", "k3"])
        self.assertEqual(put_addrs, [[2101], [2103]])
        self.assertEqual(put_sizes, [[16], [16]])


if __name__ == "__main__":
    unittest.main()
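
Assuming this lands as a standard unittest module (its path is not shown in this extract), it should run on a host without NPU hardware, e.g. `python -m unittest <path-to-test-file>`, thanks to the `torch.npu` stub at the top of the file.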

File: vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py

@@ -398,7 +398,7 @@ class AscendConnectorMetadata(KVConnectorMetadata):
 @dataclass
-class LasyerMultiBlockReqMeta:
+class LayerMultiBlockReqMeta:
     req_id: str
     keys: list[LayerPoolKey]
     starts: list[int]

File: vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py

@@ -14,7 +14,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend im
 # isort: off
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
     ChunkedTokenDatabase,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 # isort: on
@@ -48,7 +48,7 @@ class KVTransferThread(threading.Thread):
     def add_request(
         self,
-        request: ReqMeta | LasyerMultiBlockReqMeta,
+        request: ReqMeta | LayerMultiBlockReqMeta,
     ) -> torch.Tensor:
         self.request_queue.put(request)
@@ -88,22 +88,20 @@ class KVTransferThread(threading.Thread):
     def lookup(
         self,
         keys: list[str],
-    ) -> int:
+    ) -> list[bool]:
         """
-        Checks the existence of KV cache of the tokens from the cache engine.
-        :param tokens: the input tokens, with shape [seq_len]
-        :return: An int indicating how many prefix tokens are cached.
+        Check the existence of all keys from the cache engine.
+        :return: A bool list where True means the key exists in store.
         """
         try:
             res = self.m_store.exists(keys)  # type: ignore[assignment]
+            exists_list = [False] * len(keys)
             for index, value in enumerate(res):  # type: ignore[arg-type]
-                if value != 1:
-                    return index
-            # all tokens where found, return the maximal end
+                exists_list[index] = value == 1
+            return exists_list
         except Exception as e:
             logger.error(f"Remote connection failed in contains: {e}")
-            return 0
-        return len(keys)
+            return [False] * len(keys)

     def update_kv_event(self, event: list[BlockStored]):
         with self.kv_event_lock:
@@ -159,39 +157,44 @@ class KVCacheStoreSendingThread(KVTransferThread):
         starts = []
         ends = []
         keys = []
+        block_hashes = []
         if req_id not in self.stored_requests:
             self.request_queue.task_done()
             return
-        for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
+        for index, (start, end, key) in enumerate(self.token_database.process_tokens(token_len, req_meta.block_hashes)):
             starts.append(start)
             ends.append(end)
             keys.append(key.to_string())
+            block_hashes.append(req_meta.block_hashes[index])
         if not self.dcp_size > 1:
             starts = starts[self.tp_rank % self.put_step :: self.put_step]
             ends = ends[self.tp_rank % self.put_step :: self.put_step]
             keys = keys[self.tp_rank % self.put_step :: self.put_step]
+            block_hashes = block_hashes[self.tp_rank % self.put_step :: self.put_step]
         if not keys:
             self.dec_stored_request(req_id)
             return
-        skip_block_num = self.lookup(keys)
-        if skip_block_num == len(keys):
+        exists_states = self.lookup(keys)
+        missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
+        if not missing_indices:
             self.dec_stored_request(req_id)
             return
-        starts = starts[skip_block_num:]
-        ends = ends[skip_block_num:]
-        keys = keys[skip_block_num:]
+        starts = [starts[index] for index in missing_indices]
+        ends = [ends[index] for index in missing_indices]
+        keys = [keys[index] for index in missing_indices]
+        block_hashes = [block_hashes[index] for index in missing_indices]
         logger.debug(
-            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
+            "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
             len(keys),
             token_len // self.block_size,
-            skip_block_num,
+            len(missing_indices),
             req_id,
         )
@@ -206,7 +209,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
         sizes = []
         stored_events: list[BlockStored] = []
         prev_key = None
-        new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
+        new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
         for index, start in enumerate(starts):
             addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
             addrs.append(addr)
@@ -307,7 +310,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.put(req_meta)

     def _handle_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ):
         starts = req_meta.starts
         ends = req_meta.ends
@@ -330,16 +333,17 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         for key in keys:
             key_list.append(key.to_string())
-        skip_block_num = self.lookup(key_list)
-        if skip_block_num == len(key_list):
+        exists_states = self.lookup(key_list)
+        missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
+        if not missing_indices:
             if is_last_chunk and layer_id == self.final_layer_id:
                 self.set_finished_request(req_meta.req_id)
             return
-        starts = starts[skip_block_num:]
-        ends = ends[skip_block_num:]
-        key_list = key_list[skip_block_num:]
+        starts = [starts[index] for index in missing_indices]
+        ends = [ends[index] for index in missing_indices]
+        key_list = [key_list[index] for index in missing_indices]

         addr_list = []
         size_list = []
@@ -359,10 +363,10 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.task_done()
         logger.info(
-            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
-            len(keys),
+            "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
+            len(key_list),
             total_block,
-            skip_block_num,
+            len(missing_indices),
             req_meta.req_id,
         )
@@ -384,12 +388,12 @@ class KVCacheStoreLayerRecvingThread(KVTransferThread):
         self.get_event = get_event

     def add_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ) -> torch.Tensor:
         self.request_queue.put(req_meta)

     def _handle_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ):
         addr_list = []
         size_list = []

File: KVPoolWorker module (likely vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py, per the PR description)

@@ -20,7 +20,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
     AscendConnectorMetadata,
     ChunkedTokenDatabase,
     KeyMetadata,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
@@ -399,7 +399,7 @@ class KVPoolWorker:
             if not is_finish:
                 logger.info("Layerwise get failed")
         self.get_event.clear()
-        req_meta = LasyerMultiBlockReqMeta(
+        req_meta = LayerMultiBlockReqMeta(
             request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
         )
         self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
@@ -455,7 +455,7 @@ class KVPoolWorker:
         if keys:
             keys = [list(row) for row in zip(*keys)]  # [layer_num,block_num]
             for layer_id, keys_multi_chunk in enumerate(keys):
-                req_meta = LasyerMultiBlockReqMeta(
+                req_meta = LayerMultiBlockReqMeta(
                     request.req_id,
                     keys_multi_chunk,
                     starts,
@@ -602,8 +602,9 @@ class KVPoolWorker:
                 )
                 multi_tp_keys.append(new_str)
+        pp_base_keys = multi_tp_keys.copy()
         for i in range(1, self.pp_size):
-            for item in keys:
+            for item in pp_base_keys:
                 new_str = item.replace(  # type: ignore[attr-defined]
                     "@pp_rank:0", f"@pp_rank:{i}", 1
                 )
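
For readers unfamiliar with the last hunk: the pp-rank fan-out previously iterated over the un-expanded `keys` list; it now iterates over `pp_base_keys`, a snapshot of the tp-expanded keys. A hedged sketch of the pattern (the key format and list contents are made up for illustration; the real code lives in `KVPoolWorker`):

```python
# tp-expanded keys, all tagged with pp_rank 0 (illustrative format)
multi_tp_keys = ["h0@tp_rank:0@pp_rank:0", "h0@tp_rank:1@pp_rank:0"]
pp_size = 2

# Snapshot before extending, so pp variants derive from every tp key
# and the loop never reads the list it is appending to.
pp_base_keys = multi_tp_keys.copy()
for i in range(1, pp_size):
    for item in pp_base_keys:
        multi_tp_keys.append(item.replace("@pp_rank:0", f"@pp_rank:{i}", 1))

# multi_tp_keys now covers every (tp_rank, pp_rank) combination:
# ['h0@tp_rank:0@pp_rank:0', 'h0@tp_rank:1@pp_rank:0',
#  'h0@tp_rank:0@pp_rank:1', 'h0@tp_rank:1@pp_rank:1']
```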