[bugfix] [main] Fix KV cache query inconsistency across different TP ranks in the KV Pool (#5030)

### What this PR does / why we need it?
In the current KV Pool scenario for models like MLA and GQA, where
different TP ranks generate identical KV caches, the system is designed
to store only a single copy. The previous approach allowed each card to
query storage requirements dynamically, but inconsistent query results
across cards led to incorrect storage. To fix this, the new solution
pre-allocates storage responsibilities; each card now simply stores its
pre-assigned blocks, bypassing the inconsistent query step and ensuring
data correctness.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2025-12-15 21:56:05 +08:00
committed by GitHub
parent c064d11fd7
commit b662d914a4
5 changed files with 188 additions and 199 deletions

View File

@@ -19,7 +19,7 @@ class MemcacheBackend(Backend):
def __init__(self, parallel_config: ParallelConfig): def __init__(self, parallel_config: ParallelConfig):
try: try:
from memcache import DistributedObjectStore # type: ignore from memcache_hybrid import DistributedObjectStore # type: ignore
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Please install memcache by following the instructions at " "Please install memcache by following the instructions at "
@@ -43,10 +43,7 @@ class MemcacheBackend(Backend):
torch.npu.set_device(device) torch.npu.set_device(device)
def register_buffer(self, ptrs: list[int], sizes: list[int]): def register_buffer(self, ptrs: list[int], sizes: list[int]):
for ptr, size in zip(ptrs, sizes): pass
ret_value = self.store.register_buffer(ptr, size)
if ret_value != 0:
raise RuntimeError("Memcache memory registration failed.")
def exists(self, keys: list[str]) -> list[int]: def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys) return self.store.batch_is_exist(keys)

View File

@@ -374,4 +374,4 @@ class LasyerMultiBlockReqMeta:
ends: list[int] ends: list[int]
block_ids: list[int] block_ids: list[int]
layer_id: int layer_id: int
is_last_chunk: bool = True is_last_chunk: Optional[bool] = True

View File

@@ -1,29 +1,31 @@
import queue import queue
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional from typing import Any
import torch import torch
from vllm.logger import logger from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kvpool.backend.backend import Backend from vllm_ascend.distributed.kvpool.backend.backend import Backend
# isort: off # isort: off
from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase, from vllm_ascend.distributed.kvpool.config_data import (
LasyerMultiBlockReqMeta ChunkedTokenDatabase,
) LasyerMultiBlockReqMeta,
ReqMeta,
)
# isort: on # isort: on
class KVTransferThread(threading.Thread): class KVTransferThread(threading.Thread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event, block_size: int, tp_rank: int, dcp_size: int,
name: str): ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name) super().__init__(daemon=True, name=name)
self.m_store = m_store self.m_store = m_store
self.ready_event = ready_event self.ready_event = ready_event
self.block_size = block_size
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.dcp_size = dcp_size self.dcp_size = dcp_size
self.token_database = token_database self.token_database = token_database
@@ -35,22 +37,9 @@ class KVTransferThread(threading.Thread):
def add_request( def add_request(
self, self,
req_id: str, request: ReqMeta,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
is_last_chunk: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
req = ({ self.request_queue.put(request)
"req_id": req_id,
"token_len": token_len,
"block_ids": block_ids,
"block_hashes": block_hashes,
"mask_num": mask_num,
"is_last_chunk": is_last_chunk,
})
self.request_queue.put(req)
def get_and_clear_finished_requests(self) -> set[str]: def get_and_clear_finished_requests(self) -> set[str]:
""" """
@@ -82,50 +71,98 @@ class KVTransferThread(threading.Thread):
except Exception as e: except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}") logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: dict[str, Any]): def _handle_request(self, req_meta: Any):
pass pass
def lookup(
self,
keys: list[str],
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
try:
res = self.m_store.exists(keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return index
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return 0
return len(keys)
class KVCacheStoreSendingThread(KVTransferThread): class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, put_step: int, block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event): ready_event: threading.Event):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
block_size,
tp_rank, tp_rank,
dcp_size, dcp_size,
ready_event, ready_event,
name="KVCacheSendingThread") name="KVCacheSendingThread")
self.put_step = put_step self.put_step = put_step
def _handle_request(self, req_meta: dict[str, Any]): def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta["token_len"] token_len = req_meta.token_len_chunk
mask_num = req_meta["mask_num"] block_ids = req_meta.block_ids
block_ids = req_meta["block_ids"] req_id = req_meta.req_id
block_hashes = req_meta["block_hashes"] is_last_chunk = req_meta.is_last_chunk
req_id = req_meta["req_id"] starts = []
is_last_chunk = req_meta["is_last_chunk"] ends = []
addr_list = [] keys = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num): token_len, req_meta.block_hashes):
starts.append(start)
ends.append(end)
keys.append(key.to_string())
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
if not keys:
if is_last_chunk:
self.set_finished_request(req_id)
return
skip_block_num = self.lookup(keys)
if skip_block_num == len(keys):
if is_last_chunk:
self.set_finished_request(req_id)
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
keys = keys[skip_block_num:]
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
token_len // self.block_size,
skip_block_num,
req_id,
)
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value( addr, size, _ = self.token_database.prepare_value(
start, end, block_ids) start, ends[index], block_ids)
key_list.append(key.to_string()) addrs.append(addr)
addr_list.append(addr) sizes.append(size)
size_list.append(size) if keys:
if self.dcp_size > 1: self.m_store.put(keys, addrs, sizes)
self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank %
self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp:
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if is_last_chunk: if is_last_chunk:
self.set_finished_request(req_id) self.set_finished_request(req_id)
self.request_queue.task_done() self.request_queue.task_done()
@@ -134,27 +171,28 @@ class KVCacheStoreSendingThread(KVTransferThread):
class KVCacheStoreRecvingThread(KVTransferThread): class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event): block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
block_size,
tp_rank, tp_rank,
dcp_size, dcp_size,
ready_event, ready_event,
name="KVCacheStoreRecvingThread") name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: dict[str, Any]): def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta["token_len"] req_id = req_meta.req_id
mask_num = req_meta["mask_num"] mask_num = (
block_ids = req_meta["block_ids"] req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
req_id = req_meta["req_id"] // self.block_size * self.block_size)
block_hashes = req_meta["block_hashes"]
addr_list = [] addr_list = []
size_list = [] size_list = []
key_list = [] key_list = []
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num): req_meta.token_len_chunk, req_meta.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value( addr, size, _ = self.token_database.prepare_value(
start, end, block_ids) start, end, req_meta.block_ids)
key_list.append(key.to_string()) key_list.append(key.to_string())
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
@@ -175,10 +213,11 @@ class KVCacheStoreRecvingThread(KVTransferThread):
class KVCacheStoreLayerSendingThread(KVTransferThread): class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, put_step: int, block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event, num_layers: int): ready_event: threading.Event, num_layers: int):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
block_size,
tp_rank, tp_rank,
dcp_size, dcp_size,
ready_event, ready_event,
@@ -187,43 +226,74 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
self.put_step = put_step self.put_step = put_step
def add_request( # type: ignore[override] def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self, req_meta: ReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta) self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override] def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta): self, req_meta: LasyerMultiBlockReqMeta):
starts = req_meta.starts
ends = req_meta.ends
keys = req_meta.keys
layer_id = req_meta.layer_id
total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
if not keys:
if is_last_chunk:
self.set_finished_request(req_meta.req_id)
return
key_list = []
for key in keys:
key_list.append(key.to_string())
skip_block_num = self.lookup(key_list)
if skip_block_num == len(key_list):
if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
key_list = key_list[skip_block_num:]
addr_list = [] addr_list = []
size_list = [] size_list = []
key_list = [] for index, key in enumerate(key_list):
for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer( addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index], starts[index], ends[index], req_meta.block_ids, layer_id)
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
if self.dcp_size > 1:
self.m_store.put(key_list, addr_list, size_list) self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] if layer_id == self.final_layer_id and is_last_chunk:
addr_list_tp = addr_list[self.tp_rank %
self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp:
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
self.set_finished_request(req_meta.req_id) self.set_finished_request(req_meta.req_id)
self.request_queue.task_done() self.request_queue.task_done()
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
total_block,
skip_block_num,
req_meta.req_id,
)
class KVCacheStoreLayerRecvingThread(KVTransferThread): class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event, block_size: int, tp_rank: int, dcp_size: int,
get_event: threading.Event): ready_event: threading.Event, get_event: threading.Event):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
block_size,
tp_rank, tp_rank,
dcp_size, dcp_size,
ready_event, ready_event,

View File

@@ -310,8 +310,8 @@ class LookupKeyClient:
self.socket.close(linger=0) self.socket.close(linger=0)
def get_zmq_rpc_path_lookup( def get_zmq_rpc_path_lookup(vllm_config: "VllmConfig") -> str:
vllm_config: Optional["VllmConfig"] = None, ) -> str: dp_rank = vllm_config.parallel_config.data_parallel_rank
base_url = envs.VLLM_RPC_BASE_PATH base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured # Default to 0 if not configured
rpc_port = 0 rpc_port = 0
@@ -325,4 +325,4 @@ def get_zmq_rpc_path_lookup(
"It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future." "It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
) )
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}" return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}_dp_rank{dp_rank}"

View File

@@ -18,7 +18,7 @@ from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
MooncakeBackend MooncakeBackend
from vllm_ascend.distributed.kvpool.config_data import ( from vllm_ascend.distributed.kvpool.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata, AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta) LasyerMultiBlockReqMeta, ReqMeta)
from vllm_ascend.distributed.kvpool.kv_transfer import ( from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
@@ -165,28 +165,29 @@ class KVPoolWorker:
if self.kv_role in ['kv_producer', 'kv_both']: if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event() ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread( self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.tp_rank, self.m_store, self.token_database, self.block_size,
self.dcp_size, self.put_step, ready_event_sending, self.tp_rank, self.dcp_size, self.put_step,
self.num_layers) ready_event_sending, self.num_layers)
self.kv_send_thread.start() self.kv_send_thread.start()
ready_event = threading.Event() ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread( self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.tp_rank, self.dcp_size, self.m_store, self.token_database, self.block_size,
ready_event, self.get_event) self.tp_rank, self.dcp_size, ready_event, self.get_event)
self.kv_recv_thread.start() self.kv_recv_thread.start()
ready_event.wait() ready_event.wait()
else: else:
if self.kv_role in ['kv_producer', 'kv_both']: if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event() ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread( self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.tp_rank, self.m_store, self.token_database, self.block_size,
self.dcp_size, self.put_step, ready_event_sending) self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending)
self.kv_send_thread.start() self.kv_send_thread.start()
if self.load_async: if self.load_async:
ready_event = threading.Event() ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread( self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.tp_rank, self.m_store, self.token_database, self.block_size,
self.dcp_size, ready_event) self.tp_rank, self.dcp_size, ready_event)
self.kv_recv_thread.start() self.kv_recv_thread.start()
ready_event.wait() ready_event.wait()
@@ -198,38 +199,27 @@ class KVPoolWorker:
if load_spec is None or not load_spec.can_load: #load =0 if load_spec is None or not load_spec.can_load: #load =0
continue continue
token_len = request.token_len_chunk token_len = request.token_len_chunk
req_id = request.req_id
if (load_spec.kvpool_cached_tokens % self.block_size if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens != 0) and (load_spec.kvpool_cached_tokens
== token_len - 1): == token_len - 1):
token_len = request.load_spec.kvpool_cached_tokens + 1 token_len = request.load_spec.kvpool_cached_tokens + 1
else: else:
token_len = request.load_spec.kvpool_cached_tokens token_len = request.load_spec.kvpool_cached_tokens
mask_num = (request.load_spec.vllm_cached_tokens // request.token_len_chunk = token_len
self.block_size * self.block_size)
if self.use_layerwise: if self.use_layerwise:
layerwise_retriever = self.retrieve_layer( layerwise_retriever = self.retrieve_layer(request)
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
next(layerwise_retriever) # first layer load next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever) self.layerwise_retrievers.append(layerwise_retriever)
else: else:
if self.load_async: if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr] self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id, request, )
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
else: else:
addr_list = [] addr_list = []
size_list = [] size_list = []
key_list = [] key_list = []
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num): token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value( addr, size, _ = self.token_database.prepare_value(
@@ -266,40 +256,7 @@ class KVPoolWorker:
if can_save is None or not can_save: if can_save is None or not can_save:
continue continue
token_len = request.token_len_chunk layerwise_storer = self.store_layer(request)
req_id = request.req_id
# TODO: whether need to remov saveThread
# no lookup, skipmask
skip_leading_tokens = self.lookup(token_len,
request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
layerwise_storer = self.store_layer(
req_id,
token_len,
block_hashes=request.block_hashes,
mask_num=mask_num,
block_ids=request.block_ids,
is_last_chunk=request.is_last_chunk,
)
self.layerwise_storers.append(layerwise_storer) self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers: for layerwise_storer in self.layerwise_storers:
try: try:
@@ -314,45 +271,12 @@ class KVPoolWorker:
if can_save is None or not can_save: if can_save is None or not can_save:
continue continue
token_len = request.token_len_chunk
req_id = request.req_id
skip_leading_tokens = self.lookup(token_len, request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
self.kv_send_thread.add_request( # type: ignore[union-attr] self.kv_send_thread.add_request( # type: ignore[union-attr]
req_id, request, )
token_len,
request.block_ids,
request.block_hashes,
mask_num,
request.is_last_chunk,
)
def retrieve_layer( def retrieve_layer(
self, self,
req_id: str, request: ReqMeta,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
) -> Generator[Optional[torch.Tensor], None, None]: ) -> Generator[Optional[torch.Tensor], None, None]:
""" """
Retrieve the KV cache in a layerwise manner. Retrieve the KV cache in a layerwise manner.
@@ -370,6 +294,10 @@ class KVPoolWorker:
be the boolean mask indicating which tokens are retrieved and will be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration. only be returned in the last iteration.
""" """
token_len = request.token_len_chunk
mask_num = (
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
num_required_tokens = token_len - mask_num num_required_tokens = token_len - mask_num
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu") ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -379,7 +307,7 @@ class KVPoolWorker:
keys = [] keys = []
first_flag = True first_flag = True
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num): token_len, request.block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers) keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start) starts.append(start)
ends.append(end) ends.append(end)
@@ -395,8 +323,9 @@ class KVPoolWorker:
if not is_finish: if not is_finish:
logger.info("Layerwise get failed") logger.info("Layerwise get failed")
self.get_event.clear() self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, req_meta = LasyerMultiBlockReqMeta(request.req_id,
starts, ends, block_ids, keys_multi_chunk, starts,
ends, request.block_ids,
layer_id) layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type] req_meta) # type: ignore[union-attr, call-arg, arg-type]
@@ -417,12 +346,7 @@ class KVPoolWorker:
def store_layer( def store_layer(
self, self,
req_id: str, request: ReqMeta,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
is_last_chunk: bool,
mask_num: int = 0,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
""" """
Store the KV cache in a layerwise manner. Store the KV cache in a layerwise manner.
@@ -444,13 +368,11 @@ class KVPoolWorker:
storage backends. In the last iteration, it puts the memory objects storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends. of the last layer to the storage backends.
""" """
num_stored_tokens = token_len - mask_num
starts = [] starts = []
ends = [] ends = []
keys = [] keys = []
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num): request.token_len_chunk, request.block_hashes):
keys_multi_layer = key.split_layers(self.num_layers) keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start) starts.append(start)
ends.append(end) ends.append(end)
@@ -459,17 +381,17 @@ class KVPoolWorker:
if keys: if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys): for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, req_meta = LasyerMultiBlockReqMeta(request.req_id,
starts, ends, block_ids, keys_multi_chunk, starts,
layer_id, is_last_chunk) ends, request.block_ids,
layer_id,
request.is_last_chunk)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type] req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield yield
else: else:
for layer_id in range(self.num_layers): for layer_id in range(self.num_layers):
yield yield
logger.debug(
f"Stored {num_stored_tokens} out of total {token_len} tokens")
def get_finished(self) -> tuple[set[str], set[str]]: def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = ( done_sending = (