[Bugfix][0.18.0][KV Pool]Fix KV transfer put logic (#7718)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
Previously, when doing a put for the KV Pool, we found the first
non-existing key and put all blocks starting from that index; however, if
the prefix-cache blocks come from another request and some of them were
evicted due to LRU, we would put blocks that still exist in the pool,
causing MooncakeStore to print unnecessary logs in the master service.

What this PR does:

Now we look up all the keys and only put the ones that are missing.
Fix lookup_scheduler in pool_worker so it handles GQA correctly.
Fixes a few existing typos
Add UT, written by codex
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

---------

Signed-off-by: Pz1116 <zpbzpb123123@gmail.com>
Co-authored-by: DreamerLeader <2270923832@qq.com>
Co-authored-by: fems14 <1804143737@qq.com>
This commit is contained in:
pz1116
2026-03-31 20:21:23 +08:00
committed by GitHub
parent 14411e911e
commit 0b48ddbc8b
4 changed files with 168 additions and 36 deletions

View File

@@ -0,0 +1,127 @@
import threading
import unittest
from types import SimpleNamespace
import torch
# Provide a stub `torch.npu` namespace so this module can be imported on
# machines without Ascend NPU support; only `Event` is needed by the
# imports below, and `object` is a sufficient placeholder for tests.
if not hasattr(torch, "npu"):
    torch.npu = SimpleNamespace(Event=object)  # type: ignore[attr-defined]
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
LayerMultiBlockReqMeta,
ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
KVCacheStoreLayerSendingThread,
KVCacheStoreSendingThread,
)
class _FakeKey:
def __init__(self, value: str):
self._value = value
def to_string(self) -> str:
return self._value
class _FakeStore:
def __init__(self, exists_result: list[int]):
self.exists_result = exists_result
self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []
def set_device(self):
return None
def exists(self, keys: list[str]) -> list[int]:
# Return exact number of states for requested keys.
return self.exists_result[: len(keys)]
def put(self, keys, addrs, sizes):
self.put_calls.append((list(keys), list(addrs), list(sizes)))
class _FakeTokenDatabase:
def process_tokens(self, token_len, block_hashes):
for i, _ in enumerate(block_hashes):
yield i * 16, (i + 1) * 16, _FakeKey(f"k{i}")
def prepare_value(self, start, end, block_ids):
block_id = start // 16
return [1000 + block_id], [end - start], block_id
def prepare_value_layer(self, start, end, block_ids, layer_id):
block_id = start // 16
return [2000 + layer_id * 100 + block_id], [end - start]
class TestKVTransferMissingKeyPut(unittest.TestCase):
    """Verify the sending threads put only the keys missing from the store."""

    def test_sending_thread_only_puts_missing_keys(self):
        # Existence pattern: k0/k2 already stored, k1/k3 missing.
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        fake_db = _FakeTokenDatabase()
        sender = KVCacheStoreSendingThread(
            m_store=fake_store,
            token_database=fake_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            kv_role="kv_producer",
            ready_event=threading.Event(),
            enable_kv_event=False,
        )
        meta = ReqMeta(
            req_id="req-1",
            token_len_chunk=64,
            block_ids=[0, 1, 2, 3],
            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
            current_event=None,
        )
        sender.add_stored_request("req-1")
        sender.request_queue.put(meta)
        sender._handle_request(meta)

        # Exactly one put, covering only the two missing blocks.
        self.assertEqual(len(fake_store.put_calls), 1)
        keys, addrs, sizes = fake_store.put_calls[0]
        self.assertEqual(keys, ["k1", "k3"])
        self.assertEqual(addrs, [[1001], [1003]])
        self.assertEqual(sizes, [[16], [16]])

    def test_layer_sending_thread_only_puts_missing_keys(self):
        # Same existence pattern; only k1 and k3 should be written.
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        fake_db = _FakeTokenDatabase()
        layer_sender = KVCacheStoreLayerSendingThread(
            m_store=fake_store,
            token_database=fake_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            ready_event=threading.Event(),
            num_layers=2,
            enable_kv_event=False,
        )
        meta = LayerMultiBlockReqMeta(
            req_id="req-2",
            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
            starts=[0, 16, 32, 48],
            ends=[16, 32, 48, 64],
            block_ids=[0, 1, 2, 3],
            layer_id=1,
            is_last_chunk=False,
            current_event=None,
        )
        layer_sender.request_queue.put(meta)
        layer_sender._handle_request(meta)

        # One put with layer-aware addresses for the missing blocks only.
        self.assertEqual(len(fake_store.put_calls), 1)
        keys, addrs, sizes = fake_store.put_calls[0]
        self.assertEqual(keys, ["k1", "k3"])
        self.assertEqual(addrs, [[2101], [2103]])
        self.assertEqual(sizes, [[16], [16]])
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -398,7 +398,7 @@ class AscendConnectorMetadata(KVConnectorMetadata):
@dataclass @dataclass
class LasyerMultiBlockReqMeta: class LayerMultiBlockReqMeta:
req_id: str req_id: str
keys: list[LayerPoolKey] keys: list[LayerPoolKey]
starts: list[int] starts: list[int]

View File

@@ -14,7 +14,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend im
# isort: off # isort: off
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
ChunkedTokenDatabase, ChunkedTokenDatabase,
LasyerMultiBlockReqMeta, LayerMultiBlockReqMeta,
ReqMeta, ReqMeta,
) )
# isort: on # isort: on
@@ -48,7 +48,7 @@ class KVTransferThread(threading.Thread):
def add_request( def add_request(
self, self,
request: ReqMeta | LasyerMultiBlockReqMeta, request: ReqMeta | LayerMultiBlockReqMeta,
) -> torch.Tensor: ) -> torch.Tensor:
self.request_queue.put(request) self.request_queue.put(request)
@@ -88,22 +88,20 @@ class KVTransferThread(threading.Thread):
def lookup( def lookup(
self, self,
keys: list[str], keys: list[str],
) -> int: ) -> list[bool]:
""" """
Checks the existence of KV cache of the tokens from the cache engine. Check the existence of all keys from the cache engine.
:param tokens: the input tokens, with shape [seq_len] :return: A bool list where True means the key exists in store.
:return: An int indicating how many prefix tokens are cached.
""" """
try: try:
res = self.m_store.exists(keys) # type: ignore[assignment] res = self.m_store.exists(keys) # type: ignore[assignment]
exists_list = [False] * len(keys)
for index, value in enumerate(res): # type: ignore[arg-type] for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1: exists_list[index] = value == 1
return index return exists_list
# all tokens where found, return the maximal end
except Exception as e: except Exception as e:
logger.error(f"Remote connection failed in contains: {e}") logger.error(f"Remote connection failed in contains: {e}")
return 0 return [False] * len(keys)
return len(keys)
def update_kv_event(self, event: list[BlockStored]): def update_kv_event(self, event: list[BlockStored]):
with self.kv_event_lock: with self.kv_event_lock:
@@ -159,39 +157,44 @@ class KVCacheStoreSendingThread(KVTransferThread):
starts = [] starts = []
ends = [] ends = []
keys = [] keys = []
block_hashes = []
if req_id not in self.stored_requests: if req_id not in self.stored_requests:
self.request_queue.task_done() self.request_queue.task_done()
return return
for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes): for index, (start, end, key) in enumerate(self.token_database.process_tokens(token_len, req_meta.block_hashes)):
starts.append(start) starts.append(start)
ends.append(end) ends.append(end)
keys.append(key.to_string()) keys.append(key.to_string())
block_hashes.append(req_meta.block_hashes[index])
if not self.dcp_size > 1: if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step :: self.put_step] starts = starts[self.tp_rank % self.put_step :: self.put_step]
ends = ends[self.tp_rank % self.put_step :: self.put_step] ends = ends[self.tp_rank % self.put_step :: self.put_step]
keys = keys[self.tp_rank % self.put_step :: self.put_step] keys = keys[self.tp_rank % self.put_step :: self.put_step]
block_hashes = block_hashes[self.tp_rank % self.put_step :: self.put_step]
if not keys: if not keys:
self.dec_stored_request(req_id) self.dec_stored_request(req_id)
return return
skip_block_num = self.lookup(keys) exists_states = self.lookup(keys)
missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
if skip_block_num == len(keys): if not missing_indices:
self.dec_stored_request(req_id) self.dec_stored_request(req_id)
return return
starts = starts[skip_block_num:] starts = [starts[index] for index in missing_indices]
ends = ends[skip_block_num:] ends = [ends[index] for index in missing_indices]
keys = keys[skip_block_num:] keys = [keys[index] for index in missing_indices]
block_hashes = [block_hashes[index] for index in missing_indices]
logger.debug( logger.debug(
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s", "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
len(keys), len(keys),
token_len // self.block_size, token_len // self.block_size,
skip_block_num, len(missing_indices),
req_id, req_id,
) )
@@ -206,7 +209,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
sizes = [] sizes = []
stored_events: list[BlockStored] = [] stored_events: list[BlockStored] = []
prev_key = None prev_key = None
new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]] new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
for index, start in enumerate(starts): for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids) addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
addrs.append(addr) addrs.append(addr)
@@ -307,7 +310,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
self.request_queue.put(req_meta) self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override] def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta self, req_meta: LayerMultiBlockReqMeta
): ):
starts = req_meta.starts starts = req_meta.starts
ends = req_meta.ends ends = req_meta.ends
@@ -330,16 +333,17 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
for key in keys: for key in keys:
key_list.append(key.to_string()) key_list.append(key.to_string())
skip_block_num = self.lookup(key_list) exists_states = self.lookup(key_list)
missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
if skip_block_num == len(key_list): if not missing_indices:
if is_last_chunk and layer_id == self.final_layer_id: if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id) self.set_finished_request(req_meta.req_id)
return return
starts = starts[skip_block_num:] starts = [starts[index] for index in missing_indices]
ends = ends[skip_block_num:] ends = [ends[index] for index in missing_indices]
key_list = key_list[skip_block_num:] key_list = [key_list[index] for index in missing_indices]
addr_list = [] addr_list = []
size_list = [] size_list = []
@@ -359,10 +363,10 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
self.request_queue.task_done() self.request_queue.task_done()
logger.info( logger.info(
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s", "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
len(keys), len(key_list),
total_block, total_block,
skip_block_num, len(missing_indices),
req_meta.req_id, req_meta.req_id,
) )
@@ -384,12 +388,12 @@ class KVCacheStoreLayerRecvingThread(KVTransferThread):
self.get_event = get_event self.get_event = get_event
def add_request( # type: ignore[override] def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta self, req_meta: LayerMultiBlockReqMeta
) -> torch.Tensor: ) -> torch.Tensor:
self.request_queue.put(req_meta) self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override] def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta self, req_meta: LayerMultiBlockReqMeta
): ):
addr_list = [] addr_list = []
size_list = [] size_list = []

View File

@@ -20,7 +20,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
AscendConnectorMetadata, AscendConnectorMetadata,
ChunkedTokenDatabase, ChunkedTokenDatabase,
KeyMetadata, KeyMetadata,
LasyerMultiBlockReqMeta, LayerMultiBlockReqMeta,
ReqMeta, ReqMeta,
) )
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import ( from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
@@ -399,7 +399,7 @@ class KVPoolWorker:
if not is_finish: if not is_finish:
logger.info("Layerwise get failed") logger.info("Layerwise get failed")
self.get_event.clear() self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta( req_meta = LayerMultiBlockReqMeta(
request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
) )
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
@@ -455,7 +455,7 @@ class KVPoolWorker:
if keys: if keys:
keys = [list(row) for row in zip(*keys)] # [layer_num,block_num] keys = [list(row) for row in zip(*keys)] # [layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys): for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta( req_meta = LayerMultiBlockReqMeta(
request.req_id, request.req_id,
keys_multi_chunk, keys_multi_chunk,
starts, starts,
@@ -602,8 +602,9 @@ class KVPoolWorker:
) )
multi_tp_keys.append(new_str) multi_tp_keys.append(new_str)
pp_base_keys = multi_tp_keys.copy()
for i in range(1, self.pp_size): for i in range(1, self.pp_size):
for item in keys: for item in pp_base_keys:
new_str = item.replace( # type: ignore[attr-defined] new_str = item.replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{i}", 1 "@pp_rank:0", f"@pp_rank:{i}", 1
) )