[bugfix] [main] Fix KV cache query inconsistency across different TP ranks in the KV Pool (#5030)
### What this PR does / why we need it?
For models such as MLA and GQA, different TP ranks produce identical KV
caches, so the KV Pool is designed to store only a single copy. Previously,
each card queried the pool dynamically to decide what it needed to store,
but the query results could differ across cards, which led to incorrect
storage. This PR instead pre-assigns storage responsibilities: each card
simply stores the blocks assigned to it, bypassing the inconsistent query
step and ensuring data correctness. A rough sketch of the idea follows.
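
The sketch below is illustrative only and is not the code from this PR; the function names (`assigned_to_this_rank`, `blocks_to_store`) and the modulo-based assignment policy are assumptions. It shows the general shape of the approach: every rank evaluates the same deterministic ownership function, so all ranks agree on which rank stores a given block without querying the pool.

```python
# Hypothetical sketch of deterministic block-to-rank assignment.
# The real assignment in this PR lives in the KV Pool worker; the names and
# the modulo policy here are illustrative assumptions, not the actual API.


def assigned_to_this_rank(block_hash: int, tp_rank: int, tp_size: int) -> bool:
    """Map each KV block to exactly one TP rank with a pure function.

    Because every rank computes the same result for the same block hash,
    there is no per-rank pool query and no chance of divergent answers.
    """
    return block_hash % tp_size == tp_rank


def blocks_to_store(block_hashes: list[int], tp_rank: int,
                    tp_size: int) -> list[int]:
    # Each card simply keeps its pre-assigned subset of blocks.
    return [
        h for h in block_hashes
        if assigned_to_this_rank(h, tp_rank, tp_size)
    ]
```

With such a policy, rank 0 of a TP=2 setup would keep the blocks whose hashes are even and rank 1 the odd ones, and both ranks reach that conclusion independently.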
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: fems14 <1804143737@qq.com>
@@ -18,7 +18,7 @@ from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
     MooncakeBackend
 from vllm_ascend.distributed.kvpool.config_data import (
     AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
-    LasyerMultiBlockReqMeta)
+    LasyerMultiBlockReqMeta, ReqMeta)
 from vllm_ascend.distributed.kvpool.kv_transfer import (
     KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
     KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
@@ -165,28 +165,29 @@ class KVPoolWorker:
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreLayerSendingThread(
-                    self.m_store, self.token_database, self.tp_rank,
-                    self.dcp_size, self.put_step, ready_event_sending,
-                    self.num_layers)
+                    self.m_store, self.token_database, self.block_size,
+                    self.tp_rank, self.dcp_size, self.put_step,
+                    ready_event_sending, self.num_layers)
                 self.kv_send_thread.start()
             ready_event = threading.Event()
             self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
-                self.m_store, self.token_database, self.tp_rank, self.dcp_size,
-                ready_event, self.get_event)
+                self.m_store, self.token_database, self.block_size,
+                self.tp_rank, self.dcp_size, ready_event, self.get_event)
             self.kv_recv_thread.start()
             ready_event.wait()
         else:
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
-                    self.m_store, self.token_database, self.tp_rank,
-                    self.dcp_size, self.put_step, ready_event_sending)
+                    self.m_store, self.token_database, self.block_size,
+                    self.tp_rank, self.dcp_size, self.put_step,
+                    ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
                 ready_event = threading.Event()
                 self.kv_recv_thread = KVCacheStoreRecvingThread(
-                    self.m_store, self.token_database, self.tp_rank,
-                    self.dcp_size, ready_event)
+                    self.m_store, self.token_database, self.block_size,
+                    self.tp_rank, self.dcp_size, ready_event)
                 self.kv_recv_thread.start()
                 ready_event.wait()

@@ -198,38 +199,27 @@ class KVPoolWorker:
             if load_spec is None or not load_spec.can_load: #load =0
                 continue
             token_len = request.token_len_chunk
-            req_id = request.req_id
             if (load_spec.kvpool_cached_tokens % self.block_size
                     != 0) and (load_spec.kvpool_cached_tokens
                                == token_len - 1):
                 token_len = request.load_spec.kvpool_cached_tokens + 1
             else:
                 token_len = request.load_spec.kvpool_cached_tokens
-            mask_num = (request.load_spec.vllm_cached_tokens //
-                        self.block_size * self.block_size)
+            request.token_len_chunk = token_len
             if self.use_layerwise:
-                layerwise_retriever = self.retrieve_layer(
-                    req_id,
-                    token_len,
-                    request.block_ids,
-                    request.block_hashes,
-                    mask_num,
-                )
+                layerwise_retriever = self.retrieve_layer(request)
                 next(layerwise_retriever) # first layer load
                 self.layerwise_retrievers.append(layerwise_retriever)
             else:
                 if self.load_async:
                     self.kv_recv_thread.add_request( # type: ignore[union-attr]
-                        req_id,
-                        token_len,
-                        request.block_ids,
-                        request.block_hashes,
-                        mask_num,
-                    )
+                        request, )
                 else:
                     addr_list = []
                     size_list = []
                     key_list = []
+                    mask_num = (request.load_spec.vllm_cached_tokens //
+                                self.block_size * self.block_size)
                     for start, end, key in self.token_database.process_tokens(
                             token_len, request.block_hashes, mask_num):
                         addr, size, _ = self.token_database.prepare_value(
@@ -266,40 +256,7 @@ class KVPoolWorker:
             if can_save is None or not can_save:
                 continue

-            token_len = request.token_len_chunk
-            req_id = request.req_id
-
-            # TODO: whether need to remov saveThread
-            # no lookup, skipmask
-            skip_leading_tokens = self.lookup(token_len,
-                                              request.block_hashes,
-                                              self.use_layerwise)
-            if skip_leading_tokens == token_len:
-                if request.is_last_chunk:
-                    self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
-                        req_id)
-                continue # skip this request
-
-            mask_num = (skip_leading_tokens // self.block_size *
-                        self.block_size)
-
-            logger.info(
-                "Storing KV cache for %d out of %d tokens "
-                "(skip_leading_tokens=%d) for request %s",
-                token_len - skip_leading_tokens,
-                token_len,
-                skip_leading_tokens,
-                request.req_id,
-            )
-
-            layerwise_storer = self.store_layer(
-                req_id,
-                token_len,
-                block_hashes=request.block_hashes,
-                mask_num=mask_num,
-                block_ids=request.block_ids,
-                is_last_chunk=request.is_last_chunk,
-            )
+            layerwise_storer = self.store_layer(request)
             self.layerwise_storers.append(layerwise_storer)
         for layerwise_storer in self.layerwise_storers:
             try:
@@ -314,45 +271,12 @@ class KVPoolWorker:
             if can_save is None or not can_save:
                 continue

-            token_len = request.token_len_chunk
-            req_id = request.req_id
-
-            skip_leading_tokens = self.lookup(token_len, request.block_hashes,
-                                              self.use_layerwise)
-            if skip_leading_tokens == token_len:
-                if request.is_last_chunk:
-                    self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
-                        req_id)
-                continue # skip this request
-
-            mask_num = (skip_leading_tokens // self.block_size *
-                        self.block_size)
-
-            logger.info(
-                "Storing KV cache for %d out of %d tokens "
-                "(skip_leading_tokens=%d) for request %s",
-                token_len - skip_leading_tokens,
-                token_len,
-                skip_leading_tokens,
-                request.req_id,
-            )
-
             self.kv_send_thread.add_request( # type: ignore[union-attr]
-                req_id,
-                token_len,
-                request.block_ids,
-                request.block_hashes,
-                mask_num,
-                request.is_last_chunk,
-            )
+                request, )

     def retrieve_layer(
         self,
-        req_id: str,
-        token_len: int,
-        block_ids: list[int],
-        block_hashes: list[BlockHash],
-        mask_num: int = 0,
+        request: ReqMeta,
     ) -> Generator[Optional[torch.Tensor], None, None]:
         """
         Retrieve the KV cache in a layerwise manner.
@@ -370,6 +294,10 @@ class KVPoolWorker:
             be the boolean mask indicating which tokens are retrieved and will
             only be returned in the last iteration.
         """
+        token_len = request.token_len_chunk
+        mask_num = (
+            request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
+            // self.block_size * self.block_size)
         num_required_tokens = token_len - mask_num

         ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -379,7 +307,7 @@ class KVPoolWorker:
         keys = []
         first_flag = True
         for start, end, key in self.token_database.process_tokens(
-                token_len, block_hashes, mask_num):
+                token_len, request.block_hashes, mask_num):
             keys_multi_layer = key.split_layers(self.num_layers)
             starts.append(start)
             ends.append(end)
@@ -395,8 +323,9 @@ class KVPoolWorker:
             if not is_finish:
                 logger.info("Layerwise get failed")
             self.get_event.clear()
-            req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
-                                               starts, ends, block_ids,
+            req_meta = LasyerMultiBlockReqMeta(request.req_id,
+                                               keys_multi_chunk, starts,
+                                               ends, request.block_ids,
                                                layer_id)
             self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
                 req_meta) # type: ignore[union-attr, call-arg, arg-type]
@@ -417,12 +346,7 @@ class KVPoolWorker:

     def store_layer(
         self,
-        req_id: str,
-        token_len: int,
-        block_ids: list[int],
-        block_hashes: list[BlockHash],
-        is_last_chunk: bool,
-        mask_num: int = 0,
+        request: ReqMeta,
     ) -> Generator[None, None, None]:
         """
         Store the KV cache in a layerwise manner.
@@ -444,13 +368,11 @@ class KVPoolWorker:
             storage backends. In the last iteration, it puts the memory objects
             of the last layer to the storage backends.
         """
-        num_stored_tokens = token_len - mask_num
-
         starts = []
         ends = []
         keys = []
         for start, end, key in self.token_database.process_tokens(
-                token_len, block_hashes, mask_num):
+                request.token_len_chunk, request.block_hashes):
             keys_multi_layer = key.split_layers(self.num_layers)
             starts.append(start)
             ends.append(end)
@@ -459,17 +381,17 @@ class KVPoolWorker:
         if keys:
             keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
             for layer_id, keys_multi_chunk in enumerate(keys):
-                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
-                                                   starts, ends, block_ids,
-                                                   layer_id, is_last_chunk)
+                req_meta = LasyerMultiBlockReqMeta(request.req_id,
+                                                   keys_multi_chunk, starts,
+                                                   ends, request.block_ids,
+                                                   layer_id,
+                                                   request.is_last_chunk)
                 self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
                     req_meta) # type: ignore[union-attr, call-arg, arg-type]
                 yield
         else:
             for layer_id in range(self.num_layers):
                 yield
-        logger.debug(
-            f"Stored {num_stored_tokens} out of total {token_len} tokens")

     def get_finished(self) -> tuple[set[str], set[str]]:
         done_sending = (