[Bugfix][0.18.0][KV Pool] Fix KV transfer put logic (#7718)

### What this PR does / why we need it?

Previously, when putting blocks into the KV pool, we found the first non-existing key and put all blocks starting from that index. However, if the prefix-cache blocks come from another request and some of them have been evicted due to LRU, we end up putting blocks that still exist in the pool, causing MooncakeStore to print unnecessary logs in the master service.

What this PR does:

- Look up all the keys and put only the ones that are missing.
- Fix `lookup_scheduler` in pool_worker so it handles GQA correctly.
- Fix a few existing typos (notably `LasyerMultiBlockReqMeta` → `LayerMultiBlockReqMeta`).
- Add a unit test, written by Codex.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: Pz1116 <zpbzpb123123@gmail.com>
Co-authored-by: DreamerLeader <2270923832@qq.com>
Co-authored-by: fems14 <1804143737@qq.com>
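For review context, here is a minimal sketch of the new put path (a hypothetical helper shape, not the actual vllm-ascend API; the real logic lives in the `_handle_request` methods in the diff below): look up every key, keep only the missing ones, and put just those.

```python
# Sketch of "put only missing keys"; `store` stands in for MooncakeStore.
# exists() is assumed to return one state per key, with 1 meaning cached.
def put_missing_blocks(store, keys, addrs, sizes):
    exists_states = [state == 1 for state in store.exists(keys)]
    missing = [i for i, exists in enumerate(exists_states) if not exists]
    if not missing:
        return 0  # everything is already in the pool; nothing to transfer
    store.put(
        [keys[i] for i in missing],
        [addrs[i] for i in missing],
        [sizes[i] for i in missing],
    )
    return len(missing)
```

With the store states used in the unit test, `[1, 0, 1, 0]`, only the two missing blocks are transferred; the old prefix-based logic would also have re-put the cached third block.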
tests/ut/distributed/mooncake/test_kv_transfer.py (new file, 127 lines):

```python
import threading
import unittest
from types import SimpleNamespace

import torch

if not hasattr(torch, "npu"):
    torch.npu = SimpleNamespace(Event=object)  # type: ignore[attr-defined]

from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
    LayerMultiBlockReqMeta,
    ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
    KVCacheStoreLayerSendingThread,
    KVCacheStoreSendingThread,
)


class _FakeKey:

    def __init__(self, value: str):
        self._value = value

    def to_string(self) -> str:
        return self._value


class _FakeStore:

    def __init__(self, exists_result: list[int]):
        self.exists_result = exists_result
        self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []

    def set_device(self):
        return None

    def exists(self, keys: list[str]) -> list[int]:
        # Return exact number of states for requested keys.
        return self.exists_result[: len(keys)]

    def put(self, keys, addrs, sizes):
        self.put_calls.append((list(keys), list(addrs), list(sizes)))


class _FakeTokenDatabase:

    def process_tokens(self, token_len, block_hashes):
        for i, _ in enumerate(block_hashes):
            yield i * 16, (i + 1) * 16, _FakeKey(f"k{i}")

    def prepare_value(self, start, end, block_ids):
        block_id = start // 16
        return [1000 + block_id], [end - start], block_id

    def prepare_value_layer(self, start, end, block_ids, layer_id):
        block_id = start // 16
        return [2000 + layer_id * 100 + block_id], [end - start]


class TestKVTransferMissingKeyPut(unittest.TestCase):

    def test_sending_thread_only_puts_missing_keys(self):
        store = _FakeStore(exists_result=[1, 0, 1, 0])
        token_db = _FakeTokenDatabase()
        thread = KVCacheStoreSendingThread(
            m_store=store,
            token_database=token_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            kv_role="kv_producer",
            ready_event=threading.Event(),
            enable_kv_event=False,
        )

        req_meta = ReqMeta(
            req_id="req-1",
            token_len_chunk=64,
            block_ids=[0, 1, 2, 3],
            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
            current_event=None,
        )
        thread.add_stored_request("req-1")
        thread.request_queue.put(req_meta)
        thread._handle_request(req_meta)

        self.assertEqual(len(store.put_calls), 1)
        put_keys, put_addrs, put_sizes = store.put_calls[0]
        self.assertEqual(put_keys, ["k1", "k3"])
        self.assertEqual(put_addrs, [[1001], [1003]])
        self.assertEqual(put_sizes, [[16], [16]])

    def test_layer_sending_thread_only_puts_missing_keys(self):
        store = _FakeStore(exists_result=[1, 0, 1, 0])
        token_db = _FakeTokenDatabase()
        thread = KVCacheStoreLayerSendingThread(
            m_store=store,
            token_database=token_db,
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            ready_event=threading.Event(),
            num_layers=2,
            enable_kv_event=False,
        )

        req_meta = LayerMultiBlockReqMeta(
            req_id="req-2",
            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
            starts=[0, 16, 32, 48],
            ends=[16, 32, 48, 64],
            block_ids=[0, 1, 2, 3],
            layer_id=1,
            is_last_chunk=False,
            current_event=None,
        )
        thread.request_queue.put(req_meta)
        thread._handle_request(req_meta)

        self.assertEqual(len(store.put_calls), 1)
        put_keys, put_addrs, put_sizes = store.put_calls[0]
        self.assertEqual(put_keys, ["k1", "k3"])
        self.assertEqual(put_addrs, [[2101], [2103]])
        self.assertEqual(put_sizes, [[16], [16]])


if __name__ == "__main__":
    unittest.main()
```
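The asserted values follow directly from the fakes: with `exists_result=[1, 0, 1, 0]` only blocks 1 and 3 are missing, and the fake token database encodes addresses as `1000 + block_id` on the request path and `2000 + 100 * layer_id + block_id` on the layer path (here `layer_id=1`). A quick sanity check of that arithmetic:

```python
# Reproduce the expected values asserted in the tests above.
missing_blocks = [i for i, state in enumerate([1, 0, 1, 0]) if state != 1]
print(missing_blocks)                                  # [1, 3]
print([[1000 + b] for b in missing_blocks])            # [[1001], [1003]]
print([[2000 + 100 * 1 + b] for b in missing_blocks])  # [[2101], [2103]]
```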
config_data.py (vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store):

```diff
@@ -398,7 +398,7 @@ class AscendConnectorMetadata(KVConnectorMetadata):


 @dataclass
-class LasyerMultiBlockReqMeta:
+class LayerMultiBlockReqMeta:
     req_id: str
     keys: list[LayerPoolKey]
     starts: list[int]
```
kv_transfer.py (same package):

```diff
@@ -14,7 +14,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend im
 # isort: off
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
     ChunkedTokenDatabase,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 # isort: on
@@ -48,7 +48,7 @@ class KVTransferThread(threading.Thread):

     def add_request(
         self,
-        request: ReqMeta | LasyerMultiBlockReqMeta,
+        request: ReqMeta | LayerMultiBlockReqMeta,
     ) -> torch.Tensor:
         self.request_queue.put(request)

@@ -88,22 +88,20 @@ class KVTransferThread(threading.Thread):
     def lookup(
         self,
         keys: list[str],
-    ) -> int:
+    ) -> list[bool]:
         """
-        Checks the existence of KV cache of the tokens from the cache engine.
-        :param tokens: the input tokens, with shape [seq_len]
-        :return: An int indicating how many prefix tokens are cached.
+        Check the existence of all keys from the cache engine.
+        :return: A bool list where True means the key exists in store.
         """
         try:
             res = self.m_store.exists(keys)  # type: ignore[assignment]
+            exists_list = [False] * len(keys)
             for index, value in enumerate(res):  # type: ignore[arg-type]
-                if value != 1:
-                    return index
-            # all tokens where found, return the maximal end
+                exists_list[index] = value == 1
+            return exists_list
         except Exception as e:
             logger.error(f"Remote connection failed in contains: {e}")
-            return 0
-        return len(keys)
+            return [False] * len(keys)

     def update_kv_event(self, event: list[BlockStored]):
         with self.kv_event_lock:
```
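To see the contract change on concrete data (plain Python, not library code): the old `lookup` returned the length of the existing prefix, so every key after the first miss was re-put even if it was still cached; the new version reports each key individually.

```python
# Old vs. new lookup semantics for store states [1, 0, 1, 0] (1 = cached).
states = [1, 0, 1, 0]

# Old: index of the first missing key; callers sliced keys[1:], re-putting k2.
old_skip = next((i for i, s in enumerate(states) if s != 1), len(states))  # -> 1

# New: per-key existence; callers keep only the truly missing keys.
exists_states = [s == 1 for s in states]                     # [True, False, True, False]
missing = [i for i, e in enumerate(exists_states) if not e]  # [1, 3]
```

The call sites below consume this list as `missing_indices`.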
kv_transfer.py, continued:

```diff
@@ -159,39 +157,44 @@ class KVCacheStoreSendingThread(KVTransferThread):
         starts = []
         ends = []
         keys = []
+        block_hashes = []
         if req_id not in self.stored_requests:
             self.request_queue.task_done()
             return

-        for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
+        for index, (start, end, key) in enumerate(self.token_database.process_tokens(token_len, req_meta.block_hashes)):
             starts.append(start)
             ends.append(end)
             keys.append(key.to_string())
+            block_hashes.append(req_meta.block_hashes[index])

         if not self.dcp_size > 1:
             starts = starts[self.tp_rank % self.put_step :: self.put_step]
             ends = ends[self.tp_rank % self.put_step :: self.put_step]
             keys = keys[self.tp_rank % self.put_step :: self.put_step]
+            block_hashes = block_hashes[self.tp_rank % self.put_step :: self.put_step]

         if not keys:
             self.dec_stored_request(req_id)
             return

-        skip_block_num = self.lookup(keys)
+        exists_states = self.lookup(keys)
+        missing_indices = [index for index, exists in enumerate(exists_states) if not exists]

-        if skip_block_num == len(keys):
+        if not missing_indices:
             self.dec_stored_request(req_id)
             return

-        starts = starts[skip_block_num:]
-        ends = ends[skip_block_num:]
-        keys = keys[skip_block_num:]
+        starts = [starts[index] for index in missing_indices]
+        ends = [ends[index] for index in missing_indices]
+        keys = [keys[index] for index in missing_indices]
+        block_hashes = [block_hashes[index] for index in missing_indices]

         logger.debug(
-            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
+            "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
             len(keys),
             token_len // self.block_size,
-            skip_block_num,
+            len(missing_indices),
             req_id,
         )
@@ -206,7 +209,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
         sizes = []
         stored_events: list[BlockStored] = []
         prev_key = None
-        new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
+        new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
         for index, start in enumerate(starts):
             addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
             addrs.append(addr)
@@ -307,7 +310,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.put(req_meta)

     def _handle_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ):
         starts = req_meta.starts
         ends = req_meta.ends
@@ -330,16 +333,17 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         for key in keys:
             key_list.append(key.to_string())

-        skip_block_num = self.lookup(key_list)
+        exists_states = self.lookup(key_list)
+        missing_indices = [index for index, exists in enumerate(exists_states) if not exists]

-        if skip_block_num == len(key_list):
+        if not missing_indices:
             if is_last_chunk and layer_id == self.final_layer_id:
                 self.set_finished_request(req_meta.req_id)
             return

-        starts = starts[skip_block_num:]
-        ends = ends[skip_block_num:]
-        key_list = key_list[skip_block_num:]
+        starts = [starts[index] for index in missing_indices]
+        ends = [ends[index] for index in missing_indices]
+        key_list = [key_list[index] for index in missing_indices]

         addr_list = []
         size_list = []
@@ -359,10 +363,10 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.task_done()

         logger.info(
-            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
-            len(keys),
+            "Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
+            len(key_list),
             total_block,
-            skip_block_num,
+            len(missing_indices),
             req_meta.req_id,
         )
@@ -384,12 +388,12 @@ class KVCacheStoreLayerRecvingThread(KVTransferThread):
         self.get_event = get_event

     def add_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ) -> torch.Tensor:
         self.request_queue.put(req_meta)

     def _handle_request(  # type: ignore[override]
-        self, req_meta: LasyerMultiBlockReqMeta
+        self, req_meta: LayerMultiBlockReqMeta
     ):
         addr_list = []
         size_list = []
```
pool_worker.py (same package):

```diff
@@ -20,7 +20,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
     AscendConnectorMetadata,
     ChunkedTokenDatabase,
     KeyMetadata,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
     ReqMeta,
 )
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
@@ -399,7 +399,7 @@ class KVPoolWorker:
             if not is_finish:
                 logger.info("Layerwise get failed")
             self.get_event.clear()
-            req_meta = LasyerMultiBlockReqMeta(
+            req_meta = LayerMultiBlockReqMeta(
                 request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
             )
             self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
@@ -455,7 +455,7 @@ class KVPoolWorker:
         if keys:
             keys = [list(row) for row in zip(*keys)]  # [layer_num,block_num]
             for layer_id, keys_multi_chunk in enumerate(keys):
-                req_meta = LasyerMultiBlockReqMeta(
+                req_meta = LayerMultiBlockReqMeta(
                     request.req_id,
                     keys_multi_chunk,
                     starts,
```
pool_worker.py, continued:

```diff
@@ -602,8 +602,9 @@ class KVPoolWorker:
                 )
                 multi_tp_keys.append(new_str)

+        pp_base_keys = multi_tp_keys.copy()
         for i in range(1, self.pp_size):
-            for item in keys:
+            for item in pp_base_keys:
                 new_str = item.replace(  # type: ignore[attr-defined]
                     "@pp_rank:0", f"@pp_rank:{i}", 1
                 )
```
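The `lookup_scheduler` fix in context: the pipeline-parallel expansion used to iterate the original `keys`, so the TP-expanded variants appended just above (which GQA lookups presumably rely on) never received `@pp_rank` rewrites. Copying `multi_tp_keys` into `pp_base_keys` first makes every TP variant a base for the PP expansion. A rough illustration with made-up key strings (only the `@pp_rank:0` marker comes from the source):

```python
# Hypothetical key format; real keys carry more fields than shown here.
keys = ["blk0@tp_rank:0@pp_rank:0"]
tp_size, pp_size = 2, 2

multi_tp_keys = [
    k.replace("@tp_rank:0", f"@tp_rank:{t}", 1) for k in keys for t in range(tp_size)
]

pp_base_keys = multi_tp_keys.copy()  # the fix: expand from all TP variants
for i in range(1, pp_size):
    for item in pp_base_keys:        # the buggy version iterated `keys` here
        multi_tp_keys.append(item.replace("@pp_rank:0", f"@pp_rank:{i}", 1))

print(multi_tp_keys)
# ['blk0@tp_rank:0@pp_rank:0', 'blk0@tp_rank:1@pp_rank:0',
#  'blk0@tp_rank:0@pp_rank:1', 'blk0@tp_rank:1@pp_rank:1']
```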