<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? Before, when we do a put for the KV Pool, we find the first non-existing key and put all the blocks starting from that index; however, if the prefix cache blocks are from another request, and some of the blocks are evicted due to LRU, we will be putting blocks that still exist in the pool, causing MooncakeStore to print unnecessary logs in the master service. What this PR does: Now we look up all the keys and only put the ones that are missing. Fix lookup_scheduler in pool_worker so it handles GQA correctly. Fixes a few existing typos. Add UT, written by codex. <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: Pz1116 <zpbzpb123123@gmail.com> Co-authored-by: DreamerLeader <2270923832@qq.com> Co-authored-by: fems14 <1804143737@qq.com>
128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
import threading
|
|
import unittest
|
|
from types import SimpleNamespace
|
|
|
|
import torch
|
|
|
|
if not hasattr(torch, "npu"):
|
|
torch.npu = SimpleNamespace(Event=object) # type: ignore[attr-defined]
|
|
|
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
|
LayerMultiBlockReqMeta,
|
|
ReqMeta,
|
|
)
|
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
|
|
KVCacheStoreLayerSendingThread,
|
|
KVCacheStoreSendingThread,
|
|
)
|
|
|
|
|
|
class _FakeKey:
|
|
def __init__(self, value: str):
|
|
self._value = value
|
|
|
|
def to_string(self) -> str:
|
|
return self._value
|
|
|
|
|
|
class _FakeStore:
|
|
def __init__(self, exists_result: list[int]):
|
|
self.exists_result = exists_result
|
|
self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []
|
|
|
|
def set_device(self):
|
|
return None
|
|
|
|
def exists(self, keys: list[str]) -> list[int]:
|
|
# Return exact number of states for requested keys.
|
|
return self.exists_result[: len(keys)]
|
|
|
|
def put(self, keys, addrs, sizes):
|
|
self.put_calls.append((list(keys), list(addrs), list(sizes)))
|
|
|
|
|
|
class _FakeTokenDatabase:
|
|
def process_tokens(self, token_len, block_hashes):
|
|
for i, _ in enumerate(block_hashes):
|
|
yield i * 16, (i + 1) * 16, _FakeKey(f"k{i}")
|
|
|
|
def prepare_value(self, start, end, block_ids):
|
|
block_id = start // 16
|
|
return [1000 + block_id], [end - start], block_id
|
|
|
|
def prepare_value_layer(self, start, end, block_ids, layer_id):
|
|
block_id = start // 16
|
|
return [2000 + layer_id * 100 + block_id], [end - start]
|
|
|
|
|
|
class TestKVTransferMissingKeyPut(unittest.TestCase):
    """Verify that the sending threads put only the keys missing from
    the store, instead of everything after the first missing key."""

    def test_sending_thread_only_puts_missing_keys(self):
        # Blocks 0 and 2 already exist in the pool; only 1 and 3 must be put.
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        fake_db = _FakeTokenDatabase()
        sender = KVCacheStoreSendingThread(
            m_store=fake_store,
            token_database=fake_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            kv_role="kv_producer",
            ready_event=threading.Event(),
            enable_kv_event=False,
        )

        meta = ReqMeta(
            req_id="req-1",
            token_len_chunk=64,
            block_ids=[0, 1, 2, 3],
            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
            current_event=None,
        )
        sender.add_stored_request("req-1")
        sender.request_queue.put(meta)
        sender._handle_request(meta)

        # Exactly one batched put, restricted to the missing keys.
        self.assertEqual(len(fake_store.put_calls), 1)
        sent_keys, sent_addrs, sent_sizes = fake_store.put_calls[0]
        self.assertEqual(sent_keys, ["k1", "k3"])
        self.assertEqual(sent_addrs, [[1001], [1003]])
        self.assertEqual(sent_sizes, [[16], [16]])

    def test_layer_sending_thread_only_puts_missing_keys(self):
        # Same existence pattern, exercised through the layer-wise sender.
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        fake_db = _FakeTokenDatabase()
        sender = KVCacheStoreLayerSendingThread(
            m_store=fake_store,
            token_database=fake_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            ready_event=threading.Event(),
            num_layers=2,
            enable_kv_event=False,
        )

        meta = LayerMultiBlockReqMeta(
            req_id="req-2",
            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
            starts=[0, 16, 32, 48],
            ends=[16, 32, 48, 64],
            block_ids=[0, 1, 2, 3],
            layer_id=1,
            is_last_chunk=False,
            current_event=None,
        )
        sender.request_queue.put(meta)
        sender._handle_request(meta)

        # Addresses encode layer 1 (2100 + block id) for the missing blocks.
        self.assertEqual(len(fake_store.put_calls), 1)
        sent_keys, sent_addrs, sent_sizes = fake_store.put_calls[0]
        self.assertEqual(sent_keys, ["k1", "k3"])
        self.assertEqual(sent_addrs, [[2101], [2103]])
        self.assertEqual(sent_sizes, [[16], [16]])
|
|
|
|
|
|
# Allow running this test module directly (outside a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
|