### What this PR does / why we need it?
In the current KV Pool scenario for models like MLA and GQA, where
different TP ranks generate identical KV caches, the system is designed
to store only a single copy. The previous approach allowed each card to
query storage requirements dynamically, but inconsistent query results
across cards led to incorrect storage. To fix this, the new solution
pre-allocates storage responsibilities; each card now simply stores its
pre-assigned blocks, bypassing the inconsistent query step and ensuring
data correctness.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: fems14 <1804143737@qq.com>
332 lines · 12 KiB · Python
import queue
|
|
import threading
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Any
|
|
|
|
import torch
|
|
from vllm.logger import logger
|
|
|
|
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
|
|
|
# isort: off
|
|
from vllm_ascend.distributed.kvpool.config_data import (
|
|
ChunkedTokenDatabase,
|
|
LasyerMultiBlockReqMeta,
|
|
ReqMeta,
|
|
)
|
|
# isort: on
|
|
|
|
|
|
class KVTransferThread(threading.Thread):
    """Base daemon thread that drains a queue of KV-cache transfer requests.

    Subclasses implement ``_handle_request``. This base class provides the
    request queue, thread-safe bookkeeping of finished request IDs, and a
    ``lookup`` helper that measures how many leading keys already exist in
    the backing store.
    """

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 block_size: int, tp_rank: int, dcp_size: int,
                 ready_event: threading.Event, name: str):
        super().__init__(daemon=True, name=name)
        self.m_store = m_store
        # Set from run() once m_store.set_device() has been called.
        self.ready_event = ready_event
        self.block_size = block_size
        self.tp_rank = tp_rank
        self.dcp_size = dcp_size
        self.token_database = token_database
        # Guards ``finished_requests``. It is written from this thread and
        # read/cleared through get_and_clear_finished_requests(), which may
        # be called from another thread.
        self.done_task_lock = threading.Lock()
        self.request_queue: queue.Queue[Any] = queue.Queue()
        # TODO(jianzs): make this configurable
        # NOTE(review): not used by any code in this class; confirm it is
        # needed elsewhere before relying on it.
        self.executor = ThreadPoolExecutor(max_workers=32)
        self.finished_requests: set[str] = set()

    def add_request(
        self,
        request: ReqMeta,
    ) -> None:
        """Enqueue ``request`` for asynchronous handling by this thread."""
        self.request_queue.put(request)

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.

        Returns:
            A set of request IDs that have been completed since the previous
            call.
        """
        with self.done_task_lock:
            # Swap in a fresh set instead of copy + clear; the old set is no
            # longer referenced by the thread, so it can be handed out as-is.
            finished_requests, self.finished_requests = (
                self.finished_requests, set())
        return finished_requests

    def set_finished_request(self, req_id: str) -> None:
        """Record ``req_id`` as completed (thread-safe)."""
        with self.done_task_lock:
            self.finished_requests.add(req_id)

    def run(self):
        """Run the thread to handle KV cache transfer requests.

        Calls ``m_store.set_device()``, signals ``ready_event``, then handles
        queued requests forever. Exceptions raised while handling a request
        are logged with their traceback, and the loop keeps running.
        """
        self.m_store.set_device()
        self.ready_event.set()
        while True:
            try:
                request_data = self.request_queue.get()
                if request_data is None:
                    logger.warning("Received a None request!")
                    self.request_queue.task_done()
                    continue
                self._handle_request(request_data)
            except Exception as e:
                # logger.exception keeps the traceback, which is otherwise
                # lost and makes subclass failures hard to diagnose.
                logger.exception("Error in KVCacheTransferThread: %s", e)

    def _handle_request(self, req_meta: Any):
        """Handle one dequeued request; overridden by subclasses."""
        pass

    def lookup(
        self,
        keys: list[str],
    ) -> int:
        """
        Count how many leading ``keys`` already exist in the store.

        :param keys: store keys to check. The result is a prefix length, so
            the order of ``keys`` matters.
        :return: the index of the first key whose ``m_store.exists`` result is
            not 1, or ``len(keys)`` if every result is 1. Returns 0 when the
            store query raises.
        """
        try:
            res = self.m_store.exists(keys)  # type: ignore[assignment]
            for index, value in enumerate(res):  # type: ignore[arg-type]
                if value != 1:
                    return index
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
            return 0
        # All keys were found: the whole list is the cached prefix.
        return len(keys)
|
|
|
|
|
|
class KVCacheStoreSendingThread(KVTransferThread):
    """Thread that writes the KV cache blocks of request chunks to the store.

    When ``dcp_size <= 1``, the blocks of each chunk are statically split
    round-robin across ``put_step`` ranks. Rank ``tp_rank`` stores only the
    blocks at indices ``tp_rank % put_step``, ``+ put_step``, ... Presumably
    this is so that ranks holding identical KV caches (e.g. MLA/GQA) store a
    single copy; confirm against the caller that computes ``put_step``.
    """

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 block_size: int, tp_rank: int, dcp_size: int, put_step: int,
                 ready_event: threading.Event):
        super().__init__(m_store,
                         token_database,
                         block_size,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheSendingThread")
        self.put_step = put_step

    def _handle_request(self, req_meta: ReqMeta):
        """Store this rank's share of one request chunk.

        ``request_queue.task_done()`` is signalled exactly once for every
        dequeued request, on every exit path. This includes the early exits
        and the case where storing raises.
        """
        try:
            self._store_chunk(req_meta)
        finally:
            self.request_queue.task_done()

    def _store_chunk(self, req_meta: ReqMeta) -> None:
        """Partition, de-duplicate and put the blocks of one chunk."""
        token_len = req_meta.token_len_chunk
        block_ids = req_meta.block_ids
        req_id = req_meta.req_id
        is_last_chunk = req_meta.is_last_chunk

        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                token_len, req_meta.block_hashes):
            starts.append(start)
            ends.append(end)
            keys.append(key.to_string())

        if self.dcp_size <= 1:
            # Keep only the blocks pre-assigned to this rank.
            offset = self.tp_rank % self.put_step
            starts = starts[offset::self.put_step]
            ends = ends[offset::self.put_step]
            keys = keys[offset::self.put_step]

        if not keys:
            # Nothing assigned to this rank in this chunk.
            if is_last_chunk:
                self.set_finished_request(req_id)
            return

        skip_block_num = self.lookup(keys)

        if skip_block_num == len(keys):
            # Every assigned block is already in the store.
            if is_last_chunk:
                self.set_finished_request(req_id)
            return

        starts = starts[skip_block_num:]
        ends = ends[skip_block_num:]
        keys = keys[skip_block_num:]

        logger.info(
            "Storing KV cache for %d out of %d blocks "
            "(skip_block_num=%d) for request %s",
            len(keys),
            token_len // self.block_size,
            skip_block_num,
            req_id,
        )

        addrs = []
        sizes = []
        # starts/ends were built together above, so they have equal length.
        for start, end in zip(starts, ends):
            addr, size, _ = self.token_database.prepare_value(
                start, end, block_ids)
            addrs.append(addr)
            sizes.append(size)
        self.m_store.put(keys, addrs, sizes)

        if is_last_chunk:
            self.set_finished_request(req_id)
|
|
|
|
|
|
class KVCacheStoreRecvingThread(KVTransferThread):
    """Thread that loads a request's KV cache blocks from the store."""

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 block_size: int, tp_rank: int, dcp_size: int,
                 ready_event: threading.Event):
        super().__init__(m_store,
                         token_database,
                         block_size,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheStoreRecvingThread")

    @staticmethod
    def _rotate(items: list, shift: int) -> list:
        """Return a copy of ``items`` rotated left by ``shift`` positions.

        An empty input yields an empty list, so no modulo by zero occurs.
        """
        if not items:
            return []
        shift %= len(items)
        return items[shift:] + items[:shift]

    def _handle_request(self, req_meta: ReqMeta):
        """Fetch the blocks of one request and mark the request finished.

        If no blocks need loading, the store is not queried and the request
        is still marked finished. ``request_queue.task_done()`` is signalled
        on every exit path. If the store ``get`` raises, the request is NOT
        marked finished.
        """
        try:
            req_id = req_meta.req_id
            # Number of already-cached tokens, rounded down to whole blocks;
            # passed to process_tokens as its mask argument.
            mask_num = (
                req_meta.load_spec.vllm_cached_tokens  # type: ignore[union-attr]
                // self.block_size * self.block_size)
            addr_list = []
            size_list = []
            key_list = []
            for start, end, key in self.token_database.process_tokens(
                    req_meta.token_len_chunk, req_meta.block_hashes,
                    mask_num):
                addr, size, _ = self.token_database.prepare_value(
                    start, end, req_meta.block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)

            if key_list:
                # Each rank starts from a different key (offset by tp_rank),
                # presumably to spread concurrent reads — confirm intent.
                self.m_store.get(self._rotate(key_list, self.tp_rank),
                                 self._rotate(addr_list, self.tp_rank),
                                 self._rotate(size_list, self.tp_rank))
            self.set_finished_request(req_id)
        finally:
            self.request_queue.task_done()
|
|
|
|
|
|
class KVCacheStoreLayerSendingThread(KVTransferThread):
    """Layer-wise variant of the sending thread.

    Each queued ``LasyerMultiBlockReqMeta`` describes the blocks of one layer
    of one request chunk. A request is reported finished only when the last
    chunk of the final layer (``num_layers - 1``) has been handled. The same
    rule applies on every exit path, so the request ID is reported once.
    """

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 block_size: int, tp_rank: int, dcp_size: int, put_step: int,
                 ready_event: threading.Event, num_layers: int):
        super().__init__(m_store,
                         token_database,
                         block_size,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheStoreLayerSendingThread")
        self.final_layer_id = num_layers - 1
        self.put_step = put_step

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        """Enqueue one per-layer store request."""
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        """Store this rank's share of one layer's blocks.

        ``request_queue.task_done()`` is signalled exactly once for every
        dequeued request, including the early exits and failures.
        """
        try:
            self._store_layer(req_meta)
        finally:
            self.request_queue.task_done()

    def _store_layer(self, req_meta: LasyerMultiBlockReqMeta) -> None:
        """Partition, de-duplicate and put the blocks of one layer."""
        starts = req_meta.starts
        ends = req_meta.ends
        keys = req_meta.keys
        layer_id = req_meta.layer_id
        total_block = len(keys)
        # Only the last chunk of the final layer completes the request.
        finishes_request = (req_meta.is_last_chunk
                            and layer_id == self.final_layer_id)

        if self.dcp_size <= 1:
            # Keep only the blocks pre-assigned to this rank.
            offset = self.tp_rank % self.put_step
            starts = starts[offset::self.put_step]
            ends = ends[offset::self.put_step]
            keys = keys[offset::self.put_step]

        if not keys:
            if finishes_request:
                self.set_finished_request(req_meta.req_id)
            return

        key_list = [key.to_string() for key in keys]

        skip_block_num = self.lookup(key_list)

        if skip_block_num == len(key_list):
            # Every assigned block of this layer is already stored.
            if finishes_request:
                self.set_finished_request(req_meta.req_id)
            return

        starts = starts[skip_block_num:]
        ends = ends[skip_block_num:]
        key_list = key_list[skip_block_num:]

        # Report the number of blocks actually written (after skipping).
        logger.info(
            "Storing KV cache for %d out of %d blocks "
            "(skip_block_num=%d) for request %s",
            len(key_list),
            total_block,
            skip_block_num,
            req_meta.req_id,
        )

        addr_list = []
        size_list = []
        # Iterate over key_list's length, as the per-key loop did before.
        for start, end, _key in zip(starts, ends, key_list):
            addr, size = self.token_database.prepare_value_layer(
                start, end, req_meta.block_ids, layer_id)
            addr_list.append(addr)
            size_list.append(size)

        self.m_store.put(key_list, addr_list, size_list)

        if finishes_request:
            self.set_finished_request(req_meta.req_id)
|
|
|
|
|
|
class KVCacheStoreLayerRecvingThread(KVTransferThread):
    """Layer-wise loader: fetches one layer's blocks, then sets ``get_event``."""

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 block_size: int, tp_rank: int, dcp_size: int,
                 ready_event: threading.Event, get_event: threading.Event):
        super().__init__(m_store,
                         token_database,
                         block_size,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheStoreLayerRecvingThread")
        # Set after each handled layer request; the waiter is outside this
        # class.
        self.get_event = get_event

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        """Enqueue one per-layer load request."""
        self.request_queue.put(req_meta)

    @staticmethod
    def _rotate(items: list, shift: int) -> list:
        """Return a copy of ``items`` rotated left by ``shift`` positions.

        An empty input yields an empty list, so no modulo by zero occurs.
        """
        if not items:
            return []
        shift %= len(items)
        return items[shift:] + items[:shift]

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        """Load one layer's blocks, then signal ``get_event``.

        An empty key list skips the store query but still signals
        ``get_event``. ``request_queue.task_done()`` is signalled on every
        exit path.
        """
        try:
            addr_list = []
            size_list = []
            key_list = []
            for index, key in enumerate(req_meta.keys):
                addr, size = self.token_database.prepare_value_layer(
                    req_meta.starts[index], req_meta.ends[index],
                    req_meta.block_ids, req_meta.layer_id)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)

            if key_list:
                # Each rank starts from a different key (offset by tp_rank),
                # presumably to spread concurrent reads — confirm intent.
                self.m_store.get(self._rotate(key_list, self.tp_rank),
                                 self._rotate(addr_list, self.tp_rank),
                                 self._rotate(size_list, self.tp_rank))
        finally:
            self.request_queue.task_done()
        # NOTE(review): if the load above raised, get_event is deliberately
        # left unset (unchanged behavior) so the layer is never consumed with
        # missing KV data. This means the waiter may block; confirm how the
        # caller handles that.
        self.get_event.set()
|