[bugfix] [main] Fix KV cache query inconsistency across different TP ranks in the KV Pool (#5030)
### What this PR does / why we need it?
In the current KV Pool design, models using MLA or GQA attention produce identical KV caches on different TP ranks, so the pool is meant to hold only a single copy. Previously, each card queried the store at save time to decide which blocks it still needed to write, but the query results could differ across cards, which led to incorrect storage. This change pre-assigns storage responsibility instead: each card simply stores its pre-assigned blocks, bypassing the inconsistent query step and ensuring data correctness.
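
With the new scheme, block ownership is a pure function of the block index and the TP rank, so every rank reaches the same decision without consulting the store first. A minimal sketch of this striping idea (the helper name is hypothetical; the actual change uses `keys[tp_rank % put_step::put_step]` slicing inside the KV transfer threads):

```python
# Hypothetical sketch of deterministic block assignment across TP ranks.
# `put_step` is assumed to be the number of ranks that share one KV cache copy.
def blocks_owned_by_rank(keys: list[str], tp_rank: int, put_step: int) -> list[str]:
    # Every rank applies the same slicing, so the assignment is identical on
    # all ranks and no per-rank existence query is needed before storing.
    return keys[tp_rank % put_step::put_step]


# Example: 8 blocks striped across put_step=4 ranks.
keys = [f"block-{i}" for i in range(8)]
for rank in range(4):
    print(rank, blocks_owned_by_rank(keys, rank, 4))
# rank 0 -> ['block-0', 'block-4'], rank 1 -> ['block-1', 'block-5'], ...
```

Because the slice is deterministic, two ranks can never disagree about which of them stores a given block.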
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: fems14 <1804143737@qq.com>
@@ -19,7 +19,7 @@ class MemcacheBackend(Backend):

     def __init__(self, parallel_config: ParallelConfig):
         try:
-            from memcache import DistributedObjectStore # type: ignore
+            from memcache_hybrid import DistributedObjectStore # type: ignore
         except ImportError as e:
             raise ImportError(
                 "Please install memcache by following the instructions at "
@@ -43,10 +43,7 @@ class MemcacheBackend(Backend):
         torch.npu.set_device(device)

     def register_buffer(self, ptrs: list[int], sizes: list[int]):
-        for ptr, size in zip(ptrs, sizes):
-            ret_value = self.store.register_buffer(ptr, size)
-            if ret_value != 0:
-                raise RuntimeError("Memcache memory registration failed.")
+        pass

     def exists(self, keys: list[str]) -> list[int]:
         return self.store.batch_is_exist(keys)
@@ -374,4 +374,4 @@ class LasyerMultiBlockReqMeta:
     ends: list[int]
     block_ids: list[int]
     layer_id: int
-    is_last_chunk: bool = True
+    is_last_chunk: Optional[bool] = True
@@ -1,29 +1,31 @@
 import queue
 import threading
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Optional
+from typing import Any

 import torch
 from vllm.logger import logger
 from vllm.v1.core.kv_cache_utils import BlockHash

 from vllm_ascend.distributed.kvpool.backend.backend import Backend

 # isort: off
-from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
-                                                        LasyerMultiBlockReqMeta
-                                                        )
+from vllm_ascend.distributed.kvpool.config_data import (
+    ChunkedTokenDatabase,
+    LasyerMultiBlockReqMeta,
+    ReqMeta,
+)
 # isort: on


 class KVTransferThread(threading.Thread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
-                 name: str):
+                 block_size: int, tp_rank: int, dcp_size: int,
+                 ready_event: threading.Event, name: str):
         super().__init__(daemon=True, name=name)
         self.m_store = m_store
         self.ready_event = ready_event
+        self.block_size = block_size
         self.tp_rank = tp_rank
         self.dcp_size = dcp_size
         self.token_database = token_database
@@ -35,22 +37,9 @@ class KVTransferThread(threading.Thread):

     def add_request(
         self,
-        req_id: str,
-        token_len: int,
-        block_ids: list[int],
-        block_hashes: list[BlockHash],
-        mask_num: int = 0,
-        is_last_chunk: Optional[bool] = None,
+        request: ReqMeta,
     ) -> torch.Tensor:
-        req = ({
-            "req_id": req_id,
-            "token_len": token_len,
-            "block_ids": block_ids,
-            "block_hashes": block_hashes,
-            "mask_num": mask_num,
-            "is_last_chunk": is_last_chunk,
-        })
-        self.request_queue.put(req)
+        self.request_queue.put(request)

     def get_and_clear_finished_requests(self) -> set[str]:
         """
@@ -82,50 +71,98 @@ class KVTransferThread(threading.Thread):
             except Exception as e:
                 logger.error(f"Error in KVCacheTransferThread: {e}")

-    def _handle_request(self, req_meta: dict[str, Any]):
+    def _handle_request(self, req_meta: Any):
         pass

+    def lookup(
+        self,
+        keys: list[str],
+    ) -> int:
+        """
+        Checks the existence of KV cache of the tokens from the cache engine.
+        :param tokens: the input tokens, with shape [seq_len]
+        :return: An int indicating how many prefix tokens are cached.
+        """
+        try:
+            res = self.m_store.exists(keys) # type: ignore[assignment]
+            for index, value in enumerate(res): # type: ignore[arg-type]
+                if value != 1:
+                    return index
+            # all tokens where found, return the maximal end
+        except Exception as e:
+            logger.error(f"Remote connection failed in contains: {e}")
+            return 0
+        return len(keys)


 class KVCacheStoreSendingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, put_step: int,
+                 block_size: int, tp_rank: int, dcp_size: int, put_step: int,
                  ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
+                         block_size,
                          tp_rank,
                          dcp_size,
                          ready_event,
                          name="KVCacheSendingThread")
         self.put_step = put_step

-    def _handle_request(self, req_meta: dict[str, Any]):
-        token_len = req_meta["token_len"]
-        mask_num = req_meta["mask_num"]
-        block_ids = req_meta["block_ids"]
-        block_hashes = req_meta["block_hashes"]
-        req_id = req_meta["req_id"]
-        is_last_chunk = req_meta["is_last_chunk"]
-        addr_list = []
-        size_list = []
-        key_list = []
+    def _handle_request(self, req_meta: ReqMeta):
+        token_len = req_meta.token_len_chunk
+        block_ids = req_meta.block_ids
+        req_id = req_meta.req_id
+        is_last_chunk = req_meta.is_last_chunk
+        starts = []
+        ends = []
+        keys = []
         for start, end, key in self.token_database.process_tokens(
-                token_len, block_hashes, mask_num):
+                token_len, req_meta.block_hashes):
+            starts.append(start)
+            ends.append(end)
+            keys.append(key.to_string())
+
+        if not self.dcp_size > 1:
+            starts = starts[self.tp_rank % self.put_step::self.put_step]
+            ends = ends[self.tp_rank % self.put_step::self.put_step]
+            keys = keys[self.tp_rank % self.put_step::self.put_step]
+
+        if not keys:
+            if is_last_chunk:
+                self.set_finished_request(req_id)
+            return
+
+        skip_block_num = self.lookup(keys)
+
+        if skip_block_num == len(keys):
+            if is_last_chunk:
+                self.set_finished_request(req_id)
+            return
+
+        starts = starts[skip_block_num:]
+        ends = ends[skip_block_num:]
+        keys = keys[skip_block_num:]
+
+        logger.info(
+            "Storing KV cache for %d out of %d blocks "
+            "(skip_block_num=%d) for request %s",
+            len(keys),
+            token_len // self.block_size,
+            skip_block_num,
+            req_id,
+        )
+
+        addrs = []
+        sizes = []
+        for index, start in enumerate(starts):
             addr, size, _ = self.token_database.prepare_value(
-                start, end, block_ids)
-            key_list.append(key.to_string())
-            addr_list.append(addr)
-            size_list.append(size)
-        if self.dcp_size > 1:
-            self.m_store.put(key_list, addr_list, size_list)
-        else:
-            key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-            addr_list_tp = addr_list[self.tp_rank %
-                                     self.put_step::self.put_step]
-            size_list_tp = size_list[self.tp_rank %
-                                     self.put_step::self.put_step]
-            if key_list_tp:
-                self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
+                start, ends[index], block_ids)
+            addrs.append(addr)
+            sizes.append(size)
+        if keys:
+            self.m_store.put(keys, addrs, sizes)

         if is_last_chunk:
             self.set_finished_request(req_id)
         self.request_queue.task_done()
@@ -134,27 +171,28 @@ class KVCacheStoreSendingThread(KVTransferThread):
 class KVCacheStoreRecvingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, ready_event: threading.Event):
+                 block_size: int, tp_rank: int, dcp_size: int,
+                 ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
+                         block_size,
                          tp_rank,
                          dcp_size,
                          ready_event,
                          name="KVCacheStoreRecvingThread")

-    def _handle_request(self, req_meta: dict[str, Any]):
-        token_len = req_meta["token_len"]
-        mask_num = req_meta["mask_num"]
-        block_ids = req_meta["block_ids"]
-        req_id = req_meta["req_id"]
-        block_hashes = req_meta["block_hashes"]
+    def _handle_request(self, req_meta: ReqMeta):
+        req_id = req_meta.req_id
+        mask_num = (
+            req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
+            // self.block_size * self.block_size)
         addr_list = []
         size_list = []
         key_list = []
         for start, end, key in self.token_database.process_tokens(
-                token_len, block_hashes, mask_num):
+                req_meta.token_len_chunk, req_meta.block_hashes, mask_num):
             addr, size, _ = self.token_database.prepare_value(
-                start, end, block_ids)
+                start, end, req_meta.block_ids)
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
@@ -175,10 +213,11 @@ class KVCacheStoreRecvingThread(KVTransferThread):
 class KVCacheStoreLayerSendingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, put_step: int,
+                 block_size: int, tp_rank: int, dcp_size: int, put_step: int,
                  ready_event: threading.Event, num_layers: int):
         super().__init__(m_store,
                          token_database,
+                         block_size,
                          tp_rank,
                          dcp_size,
                          ready_event,
@@ -187,43 +226,74 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.put_step = put_step

     def add_request(  # type: ignore[override]
-            self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
+            self, req_meta: ReqMeta) -> torch.Tensor:
         self.request_queue.put(req_meta)

     def _handle_request(  # type: ignore[override]
             self, req_meta: LasyerMultiBlockReqMeta):
+        starts = req_meta.starts
+        ends = req_meta.ends
+        keys = req_meta.keys
+        layer_id = req_meta.layer_id
+        total_block = len(keys)
+        is_last_chunk = req_meta.is_last_chunk
+        if not self.dcp_size > 1:
+            starts = starts[self.tp_rank % self.put_step::self.put_step]
+            ends = ends[self.tp_rank % self.put_step::self.put_step]
+            keys = keys[self.tp_rank % self.put_step::self.put_step]
+
+        if not keys:
+            if is_last_chunk:
+                self.set_finished_request(req_meta.req_id)
+            return
+
+        key_list = []
+        for key in keys:
+            key_list.append(key.to_string())
+
+        skip_block_num = self.lookup(key_list)
+
+        if skip_block_num == len(key_list):
+            if is_last_chunk and layer_id == self.final_layer_id:
+                self.set_finished_request(req_meta.req_id)
+            return
+
+        starts = starts[skip_block_num:]
+        ends = ends[skip_block_num:]
+        key_list = key_list[skip_block_num:]
+
         addr_list = []
         size_list = []
-        key_list = []
-        for index, key in enumerate(req_meta.keys):
+        for index, key in enumerate(key_list):
             addr, size = self.token_database.prepare_value_layer(
-                req_meta.starts[index], req_meta.ends[index],
-                req_meta.block_ids, req_meta.layer_id)
-            key_list.append(key.to_string())
+                starts[index], ends[index], req_meta.block_ids, layer_id)
             addr_list.append(addr)
             size_list.append(size)
-        if self.dcp_size > 1:
-            self.m_store.put(key_list, addr_list, size_list)
-        else:
-            key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-            addr_list_tp = addr_list[self.tp_rank %
-                                     self.put_step::self.put_step]
-            size_list_tp = size_list[self.tp_rank %
-                                     self.put_step::self.put_step]
-            if key_list_tp:
-                self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
-        if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
+
+        self.m_store.put(key_list, addr_list, size_list)
+
+        if layer_id == self.final_layer_id and is_last_chunk:
             self.set_finished_request(req_meta.req_id)
         self.request_queue.task_done()
+
+        logger.info(
+            "Storing KV cache for %d out of %d blocks "
+            "(skip_block_num=%d) for request %s",
+            len(keys),
+            total_block,
+            skip_block_num,
+            req_meta.req_id,
+        )


 class KVCacheStoreLayerRecvingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
-                 get_event: threading.Event):
+                 block_size: int, tp_rank: int, dcp_size: int,
+                 ready_event: threading.Event, get_event: threading.Event):
         super().__init__(m_store,
                          token_database,
+                         block_size,
                          tp_rank,
                          dcp_size,
                          ready_event,
@@ -310,8 +310,8 @@ class LookupKeyClient:
         self.socket.close(linger=0)


-def get_zmq_rpc_path_lookup(
-        vllm_config: Optional["VllmConfig"] = None, ) -> str:
+def get_zmq_rpc_path_lookup(vllm_config: "VllmConfig") -> str:
+    dp_rank = vllm_config.parallel_config.data_parallel_rank
     base_url = envs.VLLM_RPC_BASE_PATH
     # Default to 0 if not configured
     rpc_port = 0
@@ -325,4 +325,4 @@ def get_zmq_rpc_path_lookup(
             "It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
         )
     logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
-    return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}"
+    return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}_dp_rank{dp_rank}"
@@ -18,7 +18,7 @@ from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
     MooncakeBackend
 from vllm_ascend.distributed.kvpool.config_data import (
     AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
-    LasyerMultiBlockReqMeta)
+    LasyerMultiBlockReqMeta, ReqMeta)
 from vllm_ascend.distributed.kvpool.kv_transfer import (
     KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
     KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
@@ -165,28 +165,29 @@ class KVPoolWorker:
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreLayerSendingThread(
-                    self.m_store, self.token_database, self.tp_rank,
-                    self.dcp_size, self.put_step, ready_event_sending,
-                    self.num_layers)
+                    self.m_store, self.token_database, self.block_size,
+                    self.tp_rank, self.dcp_size, self.put_step,
+                    ready_event_sending, self.num_layers)
                 self.kv_send_thread.start()
             ready_event = threading.Event()
             self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
-                self.m_store, self.token_database, self.tp_rank, self.dcp_size,
-                ready_event, self.get_event)
+                self.m_store, self.token_database, self.block_size,
+                self.tp_rank, self.dcp_size, ready_event, self.get_event)
             self.kv_recv_thread.start()
             ready_event.wait()
         else:
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
-                    self.m_store, self.token_database, self.tp_rank,
-                    self.dcp_size, self.put_step, ready_event_sending)
+                    self.m_store, self.token_database, self.block_size,
+                    self.tp_rank, self.dcp_size, self.put_step,
+                    ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
                 ready_event = threading.Event()
                 self.kv_recv_thread = KVCacheStoreRecvingThread(
-                    self.m_store, self.token_database, self.tp_rank,
-                    self.dcp_size, ready_event)
+                    self.m_store, self.token_database, self.block_size,
+                    self.tp_rank, self.dcp_size, ready_event)
                 self.kv_recv_thread.start()
                 ready_event.wait()
@@ -198,38 +199,27 @@ class KVPoolWorker:
             if load_spec is None or not load_spec.can_load:  #load =0
                 continue
             token_len = request.token_len_chunk
-            req_id = request.req_id
             if (load_spec.kvpool_cached_tokens % self.block_size
                     != 0) and (load_spec.kvpool_cached_tokens
                                == token_len - 1):
                 token_len = request.load_spec.kvpool_cached_tokens + 1
             else:
                 token_len = request.load_spec.kvpool_cached_tokens
-            mask_num = (request.load_spec.vllm_cached_tokens //
-                        self.block_size * self.block_size)
+            request.token_len_chunk = token_len
             if self.use_layerwise:
-                layerwise_retriever = self.retrieve_layer(
-                    req_id,
-                    token_len,
-                    request.block_ids,
-                    request.block_hashes,
-                    mask_num,
-                )
+                layerwise_retriever = self.retrieve_layer(request)
                 next(layerwise_retriever)  # first layer load
                 self.layerwise_retrievers.append(layerwise_retriever)
             else:
                 if self.load_async:
                     self.kv_recv_thread.add_request(  # type: ignore[union-attr]
-                        req_id,
-                        token_len,
-                        request.block_ids,
-                        request.block_hashes,
-                        mask_num,
-                    )
+                        request, )
                 else:
                     addr_list = []
                     size_list = []
                     key_list = []
+                    mask_num = (request.load_spec.vllm_cached_tokens //
+                                self.block_size * self.block_size)
                     for start, end, key in self.token_database.process_tokens(
                             token_len, request.block_hashes, mask_num):
                         addr, size, _ = self.token_database.prepare_value(
@@ -266,40 +256,7 @@ class KVPoolWorker:
             if can_save is None or not can_save:
                 continue

-            token_len = request.token_len_chunk
-            req_id = request.req_id
-
-            # TODO: whether need to remov saveThread
-            # no lookup, skipmask
-            skip_leading_tokens = self.lookup(token_len,
-                                              request.block_hashes,
-                                              self.use_layerwise)
-            if skip_leading_tokens == token_len:
-                if request.is_last_chunk:
-                    self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
-                        req_id)
-                continue  # skip this request
-
-            mask_num = (skip_leading_tokens // self.block_size *
-                        self.block_size)
-
-            logger.info(
-                "Storing KV cache for %d out of %d tokens "
-                "(skip_leading_tokens=%d) for request %s",
-                token_len - skip_leading_tokens,
-                token_len,
-                skip_leading_tokens,
-                request.req_id,
-            )
-
-            layerwise_storer = self.store_layer(
-                req_id,
-                token_len,
-                block_hashes=request.block_hashes,
-                mask_num=mask_num,
-                block_ids=request.block_ids,
-                is_last_chunk=request.is_last_chunk,
-            )
+            layerwise_storer = self.store_layer(request)
             self.layerwise_storers.append(layerwise_storer)
         for layerwise_storer in self.layerwise_storers:
             try:
@@ -314,45 +271,12 @@ class KVPoolWorker:
             if can_save is None or not can_save:
                 continue

-            token_len = request.token_len_chunk
-            req_id = request.req_id
-
-            skip_leading_tokens = self.lookup(token_len, request.block_hashes,
-                                              self.use_layerwise)
-            if skip_leading_tokens == token_len:
-                if request.is_last_chunk:
-                    self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
-                        req_id)
-                continue  # skip this request
-
-            mask_num = (skip_leading_tokens // self.block_size *
-                        self.block_size)
-
-            logger.info(
-                "Storing KV cache for %d out of %d tokens "
-                "(skip_leading_tokens=%d) for request %s",
-                token_len - skip_leading_tokens,
-                token_len,
-                skip_leading_tokens,
-                request.req_id,
-            )
-
             self.kv_send_thread.add_request(  # type: ignore[union-attr]
-                req_id,
-                token_len,
-                request.block_ids,
-                request.block_hashes,
-                mask_num,
-                request.is_last_chunk,
-            )
+                request, )

     def retrieve_layer(
         self,
-        req_id: str,
-        token_len: int,
-        block_ids: list[int],
-        block_hashes: list[BlockHash],
-        mask_num: int = 0,
+        request: ReqMeta,
     ) -> Generator[Optional[torch.Tensor], None, None]:
         """
         Retrieve the KV cache in a layerwise manner.
@@ -370,6 +294,10 @@ class KVPoolWorker:
         be the boolean mask indicating which tokens are retrieved and will
        only be returned in the last iteration.
         """
+        token_len = request.token_len_chunk
+        mask_num = (
+            request.load_spec.vllm_cached_tokens  # type: ignore[union-attr]
+            // self.block_size * self.block_size)
         num_required_tokens = token_len - mask_num

         ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -379,7 +307,7 @@ class KVPoolWorker:
         keys = []
         first_flag = True
         for start, end, key in self.token_database.process_tokens(
-                token_len, block_hashes, mask_num):
+                token_len, request.block_hashes, mask_num):
             keys_multi_layer = key.split_layers(self.num_layers)
             starts.append(start)
             ends.append(end)
@@ -395,8 +323,9 @@ class KVPoolWorker:
                 if not is_finish:
                     logger.info("Layerwise get failed")
                     self.get_event.clear()
-            req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
-                                               starts, ends, block_ids,
-                                               layer_id)
+            req_meta = LasyerMultiBlockReqMeta(request.req_id,
+                                               keys_multi_chunk, starts,
+                                               ends, request.block_ids,
+                                               layer_id)
             self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
                 req_meta)  # type: ignore[union-attr, call-arg, arg-type]
@@ -417,12 +346,7 @@ class KVPoolWorker:

     def store_layer(
         self,
-        req_id: str,
-        token_len: int,
-        block_ids: list[int],
-        block_hashes: list[BlockHash],
-        is_last_chunk: bool,
-        mask_num: int = 0,
+        request: ReqMeta,
     ) -> Generator[None, None, None]:
         """
         Store the KV cache in a layerwise manner.
@@ -444,13 +368,11 @@ class KVPoolWorker:
         storage backends. In the last iteration, it puts the memory objects
         of the last layer to the storage backends.
         """
-        num_stored_tokens = token_len - mask_num
-
         starts = []
         ends = []
         keys = []
         for start, end, key in self.token_database.process_tokens(
-                token_len, block_hashes, mask_num):
+                request.token_len_chunk, request.block_hashes):
             keys_multi_layer = key.split_layers(self.num_layers)
             starts.append(start)
             ends.append(end)
@@ -459,17 +381,17 @@ class KVPoolWorker:
         if keys:
             keys = [list(row) for row in zip(*keys)]  #[layer_num,block_num]
             for layer_id, keys_multi_chunk in enumerate(keys):
-                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
-                                                   starts, ends, block_ids,
-                                                   layer_id, is_last_chunk)
+                req_meta = LasyerMultiBlockReqMeta(request.req_id,
+                                                   keys_multi_chunk, starts,
+                                                   ends, request.block_ids,
+                                                   layer_id,
+                                                   request.is_last_chunk)
                 self.kv_send_thread.add_request(  # type: ignore[union-attr, call-arg]
                     req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                 yield
         else:
             for layer_id in range(self.num_layers):
                 yield
-        logger.debug(
-            f"Stored {num_stored_tokens} out of total {token_len} tokens")

     def get_finished(self) -> tuple[set[str], set[str]]:
         done_sending = (