[bugfix] [main] Fix KV cache query inconsistency across different TP ranks in the KV Pool (#5030)

### What this PR does / why we need it?
In the current KV Pool scenario for models like MLA and GQA, where
different TP ranks generate identical KV caches, the system is designed
to store only a single copy. The previous approach allowed each card to
query storage requirements dynamically, but inconsistent query results
across cards led to incorrect storage. To fix this, the new solution
pre-allocates storage responsibilities; each card now simply stores its
pre-assigned blocks, bypassing the inconsistent query step and ensuring
data correctness.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2025-12-15 21:56:05 +08:00
committed by GitHub
parent c064d11fd7
commit b662d914a4
5 changed files with 188 additions and 199 deletions

View File

@@ -18,7 +18,7 @@ from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kvpool.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta)
LasyerMultiBlockReqMeta, ReqMeta)
from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
@@ -165,28 +165,29 @@ class KVPoolWorker:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.dcp_size, self.put_step, ready_event_sending,
self.num_layers)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending, self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.tp_rank, self.dcp_size,
ready_event, self.get_event)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event, self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.dcp_size, self.put_step, ready_event_sending)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.tp_rank,
self.dcp_size, ready_event)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()
@@ -198,38 +199,27 @@ class KVPoolWorker:
if load_spec is None or not load_spec.can_load: #load =0
continue
token_len = request.token_len_chunk
req_id = request.req_id
if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens
== token_len - 1):
token_len = request.load_spec.kvpool_cached_tokens + 1
else:
token_len = request.load_spec.kvpool_cached_tokens
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
request.token_len_chunk = token_len
if self.use_layerwise:
layerwise_retriever = self.retrieve_layer(
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
layerwise_retriever = self.retrieve_layer(request)
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
request, )
else:
addr_list = []
size_list = []
key_list = []
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
@@ -266,40 +256,7 @@ class KVPoolWorker:
if can_save is None or not can_save:
continue
token_len = request.token_len_chunk
req_id = request.req_id
# TODO: whether need to remov saveThread
# no lookup, skipmask
skip_leading_tokens = self.lookup(token_len,
request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
layerwise_storer = self.store_layer(
req_id,
token_len,
block_hashes=request.block_hashes,
mask_num=mask_num,
block_ids=request.block_ids,
is_last_chunk=request.is_last_chunk,
)
layerwise_storer = self.store_layer(request)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
@@ -314,45 +271,12 @@ class KVPoolWorker:
if can_save is None or not can_save:
continue
token_len = request.token_len_chunk
req_id = request.req_id
skip_leading_tokens = self.lookup(token_len, request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
self.kv_send_thread.add_request( # type: ignore[union-attr]
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
request.is_last_chunk,
)
request, )
def retrieve_layer(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
request: ReqMeta,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
@@ -370,6 +294,10 @@ class KVPoolWorker:
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
"""
token_len = request.token_len_chunk
mask_num = (
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
num_required_tokens = token_len - mask_num
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -379,7 +307,7 @@ class KVPoolWorker:
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
token_len, request.block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -395,8 +323,9 @@ class KVPoolWorker:
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
@@ -417,12 +346,7 @@ class KVPoolWorker:
def store_layer(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
is_last_chunk: bool,
mask_num: int = 0,
request: ReqMeta,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
@@ -444,13 +368,11 @@ class KVPoolWorker:
storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends.
"""
num_stored_tokens = token_len - mask_num
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
request.token_len_chunk, request.block_hashes):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -459,17 +381,17 @@ class KVPoolWorker:
if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id, is_last_chunk)
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id,
request.is_last_chunk)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield
logger.debug(
f"Stored {num_stored_tokens} out of total {token_len} tokens")
def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (