[bugfix] [main] Fix KV cache query inconsistency across different TP ranks in the KV Pool (#5030)

### What this PR does / why we need it?
In the current KV Pool scenario for models like MLA and GQA, where
different TP ranks generate identical KV caches, the system is designed
to store only a single copy. The previous approach allowed each card to
query storage requirements dynamically, but inconsistent query results
across cards led to incorrect storage. To fix this, the new solution
pre-allocates storage responsibilities; each card now simply stores its
pre-assigned blocks, bypassing the inconsistent query step and ensuring
data correctness.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2025-12-15 21:56:05 +08:00
committed by GitHub
parent c064d11fd7
commit b662d914a4
5 changed files with 188 additions and 199 deletions

View File

@@ -1,29 +1,31 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
from typing import Any
import torch
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kvpool.backend.backend import Backend
# isort: off
from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
LasyerMultiBlockReqMeta
)
from vllm_ascend.distributed.kvpool.config_data import (
ChunkedTokenDatabase,
LasyerMultiBlockReqMeta,
ReqMeta,
)
# isort: on
class KVTransferThread(threading.Thread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event,
name: str):
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.m_store = m_store
self.ready_event = ready_event
self.block_size = block_size
self.tp_rank = tp_rank
self.dcp_size = dcp_size
self.token_database = token_database
@@ -35,22 +37,9 @@ class KVTransferThread(threading.Thread):
def add_request(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
is_last_chunk: Optional[bool] = None,
request: ReqMeta,
) -> torch.Tensor:
req = ({
"req_id": req_id,
"token_len": token_len,
"block_ids": block_ids,
"block_hashes": block_hashes,
"mask_num": mask_num,
"is_last_chunk": is_last_chunk,
})
self.request_queue.put(req)
self.request_queue.put(request)
def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -82,50 +71,98 @@ class KVTransferThread(threading.Thread):
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: dict[str, Any]):
def _handle_request(self, req_meta: Any):
pass
def lookup(
self,
keys: list[str],
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
try:
res = self.m_store.exists(keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return index
# all keys were found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return 0
return len(keys)
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, put_step: int,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheSendingThread")
self.put_step = put_step
def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
block_hashes = req_meta["block_hashes"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
addr_list = []
size_list = []
key_list = []
def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.token_len_chunk
block_ids = req_meta.block_ids
req_id = req_meta.req_id
is_last_chunk = req_meta.is_last_chunk
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
token_len, req_meta.block_hashes):
starts.append(start)
ends.append(end)
keys.append(key.to_string())
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
if not keys:
if is_last_chunk:
self.set_finished_request(req_id)
return
skip_block_num = self.lookup(keys)
if skip_block_num == len(keys):
if is_last_chunk:
self.set_finished_request(req_id)
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
keys = keys[skip_block_num:]
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
token_len // self.block_size,
skip_block_num,
req_id,
)
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
if self.dcp_size > 1:
self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank %
self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp:
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if keys:
self.m_store.put(keys, addrs, sizes)
if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
@@ -134,27 +171,28 @@ class KVCacheStoreSendingThread(KVTransferThread):
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event):
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
block_hashes = req_meta["block_hashes"]
def _handle_request(self, req_meta: ReqMeta):
req_id = req_meta.req_id
mask_num = (
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
req_meta.token_len_chunk, req_meta.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
start, end, req_meta.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
@@ -175,10 +213,11 @@ class KVCacheStoreRecvingThread(KVTransferThread):
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, put_step: int,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event, num_layers: int):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
@@ -187,43 +226,74 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
self.put_step = put_step
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self, req_meta: ReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
starts = req_meta.starts
ends = req_meta.ends
keys = req_meta.keys
layer_id = req_meta.layer_id
total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
if not keys:
if is_last_chunk:
self.set_finished_request(req_meta.req_id)
return
key_list = []
for key in keys:
key_list.append(key.to_string())
skip_block_num = self.lookup(key_list)
if skip_block_num == len(key_list):
if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
key_list = key_list[skip_block_num:]
addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
for index, key in enumerate(key_list):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
starts[index], ends[index], req_meta.block_ids, layer_id)
addr_list.append(addr)
size_list.append(size)
if self.dcp_size > 1:
self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank %
self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp:
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
self.m_store.put(key_list, addr_list, size_list)
if layer_id == self.final_layer_id and is_last_chunk:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
total_block,
skip_block_num,
req_meta.req_id,
)
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event,
get_event: threading.Event):
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, get_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,