128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
|
|
import threading
|
||
|
|
import unittest
|
||
|
|
from types import SimpleNamespace
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# On hosts whose torch build has no ``npu`` attribute, install a minimal
# stand-in exposing only ``Event`` so the vllm_ascend imports below can load.
# NOTE(review): presumably those modules reference ``torch.npu.Event`` at
# import time — confirm against kv_transfer.py.
if not hasattr(torch, "npu"):
    torch.npu = SimpleNamespace(Event=object)  # type: ignore[attr-defined]
|
||
|
|
|
||
|
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||
|
|
LayerMultiBlockReqMeta,
|
||
|
|
ReqMeta,
|
||
|
|
)
|
||
|
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
|
||
|
|
KVCacheStoreLayerSendingThread,
|
||
|
|
KVCacheStoreSendingThread,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeKey:
    """Test double for a pool key: wraps a string and returns it from ``to_string``."""

    def __init__(self, value: str):
        # Stored verbatim; ``to_string`` hands it back unchanged.
        self._text = value

    def to_string(self) -> str:
        """Return the string this key was constructed with."""
        return self._text
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeStore:
    """In-memory test double for the KV store handed to the sending threads.

    ``exists`` replays a canned list of presence flags, and ``put`` records
    every call in ``put_calls`` so a test can inspect what was written.
    """

    def __init__(self, exists_result: list[int]):
        # One (keys, addrs, sizes) tuple is appended per ``put`` call.
        self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []
        # Canned answers for ``exists``; the tests in this file use 1 for
        # keys that must not be re-sent and 0 for keys that must be put.
        self.exists_result = exists_result

    def set_device(self):
        """Do nothing; present only so callers can invoke it on the fake."""
        return None

    def exists(self, keys: list[str]) -> list[int]:
        """Return the leading canned flags, at most one per requested key."""
        wanted = len(keys)
        return self.exists_result[:wanted]

    def put(self, keys, addrs, sizes):
        """Record the call, copying each argument into a fresh list."""
        recorded = (list(keys), list(addrs), list(sizes))
        self.put_calls.append(recorded)
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeTokenDatabase:
    """Deterministic test double for the token database.

    Every block is treated as 16 tokens wide, matching the ``block_size=16``
    the tests in this file pass to the sending threads.
    """

    def process_tokens(self, token_len, block_hashes):
        """Yield ``(start, end, key)`` for each hash, with keys ``k0``, ``k1``, ...

        ``token_len`` is accepted but not used.
        """
        for index, _unused_hash in enumerate(block_hashes):
            start = index * 16
            yield start, start + 16, _FakeKey(f"k{index}")

    def prepare_value(self, start, end, block_ids):
        """Return ``([addr], [size], block_id)`` derived from the token range.

        The fake address is ``1000 + block index``; ``block_ids`` is ignored.
        """
        block_index = start // 16
        addrs = [1000 + block_index]
        sizes = [end - start]
        return addrs, sizes, block_index

    def prepare_value_layer(self, start, end, block_ids, layer_id):
        """Return ``([addr], [size])`` for a single layer of one block.

        The fake address is ``2000 + layer_id * 100 + block index``;
        ``block_ids`` is ignored.
        """
        block_index = start // 16
        layer_base = 2000 + layer_id * 100
        return [layer_base + block_index], [end - start]
|
||
|
|
|
||
|
|
|
||
|
|
class TestKVTransferMissingKeyPut(unittest.TestCase):
    """Check that the sending threads only ``put`` keys the store lacks.

    Both tests use a fake store whose ``exists`` answers ``[1, 0, 1, 0]`` for
    the keys ``k0``..``k3`` and expect exactly one ``put`` call carrying only
    ``k1`` and ``k3``. The threads are never started; ``_handle_request`` is
    invoked synchronously on the test thread.
    """

    def test_sending_thread_only_puts_missing_keys(self):
        """Whole-request path: ``KVCacheStoreSendingThread`` fed a ``ReqMeta``."""
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        fake_db = _FakeTokenDatabase()
        sender = KVCacheStoreSendingThread(
            m_store=fake_store,
            token_database=fake_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            kv_role="kv_producer",
            ready_event=threading.Event(),
            enable_kv_event=False,
        )

        # Four 16-token blocks -> the fake database yields keys k0..k3.
        request = ReqMeta(
            req_id="req-1",
            token_len_chunk=64,
            block_ids=[0, 1, 2, 3],
            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
            current_event=None,
        )
        sender.add_stored_request("req-1")
        # NOTE(review): the request is enqueued before being handled directly,
        # presumably so the handler's queue bookkeeping balances — confirm.
        sender.request_queue.put(request)
        sender._handle_request(request)

        self.assertEqual(len(fake_store.put_calls), 1)
        sent_keys, sent_addrs, sent_sizes = fake_store.put_calls[0]
        self.assertEqual(sent_keys, ["k1", "k3"])
        # _FakeTokenDatabase.prepare_value uses 1000 + block index as address.
        self.assertEqual(sent_addrs, [[1001], [1003]])
        self.assertEqual(sent_sizes, [[16], [16]])

    def test_layer_sending_thread_only_puts_missing_keys(self):
        """Layer-wise path: ``KVCacheStoreLayerSendingThread`` fed a ``LayerMultiBlockReqMeta``."""
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        fake_db = _FakeTokenDatabase()
        sender = KVCacheStoreLayerSendingThread(
            m_store=fake_store,
            token_database=fake_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            ready_event=threading.Event(),
            num_layers=2,
            enable_kv_event=False,
        )

        # Keys and token ranges are supplied explicitly for layer 1.
        request = LayerMultiBlockReqMeta(
            req_id="req-2",
            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
            starts=[0, 16, 32, 48],
            ends=[16, 32, 48, 64],
            block_ids=[0, 1, 2, 3],
            layer_id=1,
            is_last_chunk=False,
            current_event=None,
        )
        # NOTE(review): enqueued before the direct call, presumably for the
        # handler's queue bookkeeping — confirm.
        sender.request_queue.put(request)
        sender._handle_request(request)

        self.assertEqual(len(fake_store.put_calls), 1)
        sent_keys, sent_addrs, sent_sizes = fake_store.put_calls[0]
        self.assertEqual(sent_keys, ["k1", "k3"])
        # prepare_value_layer: 2000 + layer_id * 100 + block index, layer_id=1.
        self.assertEqual(sent_addrs, [[2101], [2103]])
        self.assertEqual(sent_sizes, [[16], [16]])
|
||
|
|
|
||
|
|
|
||
|
|
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|