### What this PR does / why we need it?
In the KV Pool scenario, models such as those using MLA or GQA can have
several TP ranks that produce identical KV caches, and the system is
designed to store only a single copy of such caches. Previously, each card
queried the storage dynamically to decide what it still had to store, but
the query results could differ across cards, which led to incorrect
storage. This PR instead pre-assigns storage responsibilities: each card
stores only the blocks allocated to it, which skips the inconsistent query
step and guarantees data correctness.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: fems14 <1804143737@qq.com>
528 lines
22 KiB
Python
528 lines
22 KiB
Python
import math
|
|
import threading
|
|
from typing import Dict, Generator, Optional, Type
|
|
|
|
import torch
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import (get_decode_context_model_parallel_rank,
|
|
get_decode_context_model_parallel_world_size,
|
|
get_pcp_group, get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size)
|
|
from vllm.logger import logger
|
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
|
|
|
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
|
from vllm_ascend.distributed.kvpool.backend.memcache_backend import \
|
|
MemcacheBackend
|
|
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
|
|
MooncakeBackend
|
|
from vllm_ascend.distributed.kvpool.config_data import (
|
|
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
|
|
LasyerMultiBlockReqMeta, ReqMeta)
|
|
from vllm_ascend.distributed.kvpool.kv_transfer import (
|
|
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
|
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
|
|
|
# Registry of KV pool store implementations. ``KVPoolWorker`` looks up the
# lower-cased ``backend`` entry of ``kv_connector_extra_config`` here.
backend_map: Dict[str, Type[Backend]] = dict(
    mooncake=MooncakeBackend,
    memcache=MemcacheBackend,
)
|
|
|
|
|
|
class KVPoolWorker:
    """Worker-side engine that moves KV-cache blocks between the local
    device caches and a remote KV pool (the store class comes from
    ``backend_map``).

    Pool keys are built by ``ChunkedTokenDatabase`` from a ``KeyMetadata``
    holding the model name, ``head_or_tp_rank``, PCP rank, DCP rank and PP
    rank. When the model has fewer KV heads than TP ranks (MLA counts as a
    single KV head), ``put_step`` consecutive TP ranks share one
    ``head_or_tp_rank`` and therefore produce identical keys. ``put_step`` is
    handed to the sending threads (presumably so that the ranks sharing a
    head split the store work instead of each storing the same blocks).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        use_layerwize: bool,
    ):
        """Read the parallel/cache layout from ``vllm_config`` and create the
        key database and backend store.

        :param vllm_config: the engine configuration.
        :param use_layerwize: transfer KV caches layer by layer instead of
            whole blocks at once.
        :raises ValueError: if the configured ``backend`` is not a key of
            ``backend_map``.
        """
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.dp_rank = parallel_config.data_parallel_rank
        # Only an explicit boolean ``use_mla=True`` enables the MLA layout.
        self.use_mla = False
        if (hasattr(model_config, "use_mla")
                and isinstance(model_config.use_mla, bool)
                and model_config.use_mla):
            self.use_mla = True
        self.use_layerwise = use_layerwize
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.pp_size = parallel_config.pipeline_parallel_size
        self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size

        self.pcp_size = get_pcp_group().world_size
        self.pcp_rank = (get_pcp_group().rank_in_group
                         if self.pcp_size > 1 else 0)
        self.dcp_size = get_decode_context_model_parallel_world_size()
        self.dcp_rank = (get_decode_context_model_parallel_rank()
                         if self.dcp_size > 1 else 0)

        transfer_config = vllm_config.kv_transfer_config
        extra_config = transfer_config.kv_connector_extra_config
        self.kv_role = transfer_config.kv_role
        self.load_async = extra_config.get("load_async", False)
        self.backend = extra_config.get("backend", "mooncake")

        # One pool block covers cache block_size * pcp_size * dcp_size tokens.
        self.block_size = vllm_config.cache_config.block_size
        if self.pcp_size > 1:
            self.block_size *= self.pcp_size
        if self.dcp_size > 1:
            self.block_size *= self.dcp_size
        self.current_layer = 0
        self.num_layers = model_config.get_num_layers(parallel_config)

        if self.use_mla:
            self.num_kv_head = 1
        else:
            self.num_kv_head = model_config.get_total_num_kv_heads()

        # With fewer KV heads than TP ranks, each group of ``put_step``
        # consecutive TP ranks shares one head index in the pool keys.
        if self.num_kv_head < self.tp_size:
            self.put_step = self.tp_size // self.num_kv_head
            self.head_or_tp_rank = self.tp_rank // self.put_step
        else:
            self.head_or_tp_rank = self.tp_rank
            self.put_step = 1

        self.metadata = KeyMetadata(
            model_config.model.split('/')[-1],
            self.head_or_tp_rank,
            self.pcp_rank,
            self.dcp_rank,
            self.pp_rank,
        )

        self.token_database = ChunkedTokenDatabase(self.metadata,
                                                   self.block_size,
                                                   self.use_mla)

        real_backend = backend_map.get(self.backend.lower())
        if real_backend is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not callable".
            raise ValueError(
                f"Unsupported KV pool backend {self.backend!r}; "
                f"supported backends: {sorted(backend_map)}")
        self.m_store = real_backend(  # type: ignore[misc]
            parallel_config)

        self.kv_send_thread: Optional[KVTransferThread] = None
        self.kv_recv_thread: Optional[KVTransferThread] = None

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register all KV-cache regions with the backend and start the
        transfer threads.

        Each value of ``kv_caches`` is iterated as a sequence of cache
        tensors (for MLA: a (norm, pe) pair). Every tensor's base address and
        ``num_blocks * block_len`` byte length is registered with the store,
        and the base addresses and per-block byte lengths are passed to the
        token database. Sending threads are started only for the
        ``kv_producer``/``kv_both`` roles; a receiving thread is started in
        layerwise mode or when ``load_async`` is set, and this call waits on
        its ready event.
        """
        _, first_kv_cache_tuple = next(iter(kv_caches.items()))
        first_kv_cache = first_kv_cache_tuple[0]
        self.num_blocks = first_kv_cache.shape[0]
        block_rank = 3  # number of trailing dims that make up one block

        # TODO(tms): Find a more robust way to detect and handle MLA
        if self.use_mla:
            # MLA case: (norm, pe) tensors, each [num_block, block_size, 1,
            # dim]. Size each block from its own tensor so the pe length
            # stays correct even if the two tensors' dtypes differ.
            norm_cache = first_kv_cache_tuple[0]
            pe_cache = first_kv_cache_tuple[1]
            block_shape_norm = norm_cache.shape[-block_rank:]
            block_shape_pe = pe_cache.shape[-block_rank:]
            self.block_len = [
                norm_cache.element_size() * math.prod(block_shape_norm),
                pe_cache.element_size() * math.prod(block_shape_pe),
            ]
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                self.num_blocks, block_shape_norm, block_shape_pe)
        else:
            # [num_block, block_size, num_head, hidden_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            self.block_len = [
                first_kv_cache.element_size() * math.prod(block_shape)
            ]
            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                        block_shape)

        logger.info("Registering KV_Caches. use_mla: %s, shape %s",
                    self.use_mla, first_kv_cache.shape)

        self.kv_caches = kv_caches
        self.kv_caches_base_addr = []
        ptrs = []
        lengths = []
        for cache_or_caches in kv_caches.values():
            for i, cache in enumerate(cache_or_caches):
                base_addr = cache.data_ptr()
                self.kv_caches_base_addr.append(base_addr)
                ptrs.append(base_addr)
                # MLA alternates norm/pe tensors with different block sizes;
                # the non-MLA layout has a single block size.
                block_len = (self.block_len[i % 2]
                             if self.use_mla else self.block_len[0])
                lengths.append(self.num_blocks * block_len)
        self.m_store.register_buffer(ptrs, lengths)
        self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
        self.token_database.set_block_len(self.block_len)

        if self.use_layerwise:
            self.get_event = threading.Event()
            if self.kv_role in ['kv_producer', 'kv_both']:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreLayerSendingThread(
                    self.m_store, self.token_database, self.block_size,
                    self.tp_rank, self.dcp_size, self.put_step,
                    ready_event_sending, self.num_layers)
                self.kv_send_thread.start()
            ready_event = threading.Event()
            self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
                self.m_store, self.token_database, self.block_size,
                self.tp_rank, self.dcp_size, ready_event, self.get_event)
            self.kv_recv_thread.start()
            ready_event.wait()
        else:
            if self.kv_role in ['kv_producer', 'kv_both']:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreSendingThread(
                    self.m_store, self.token_database, self.block_size,
                    self.tp_rank, self.dcp_size, self.put_step,
                    ready_event_sending)
                self.kv_send_thread.start()
            if self.load_async:
                ready_event = threading.Event()
                self.kv_recv_thread = KVCacheStoreRecvingThread(
                    self.m_store, self.token_database, self.block_size,
                    self.tp_rank, self.dcp_size, ready_event)
                self.kv_recv_thread.start()
                ready_event.wait()

    def start_load_kv(self, metadata: AscendConnectorMetadata):
        """Start loading the pool-cached tokens of every loadable request.

        ``request.token_len_chunk`` is overwritten with the number of tokens
        to load. The load is then either started layer by layer, handed to
        the async receiving thread, or performed synchronously.
        """
        self.current_layer = 0
        self.layerwise_retrievers = []
        for request in metadata.requests:
            load_spec = request.load_spec
            if load_spec is None or not load_spec.can_load:  # nothing to load
                continue
            cached = load_spec.kvpool_cached_tokens
            # When the pool hit stops mid-block exactly one token short of
            # the chunk, load one extra token (presumably the last token that
            # the hit count excludes but the stored block contains).
            if (cached % self.block_size != 0
                    and cached == request.token_len_chunk - 1):
                token_len = cached + 1
            else:
                token_len = cached
            request.token_len_chunk = token_len

            if self.use_layerwise:
                layerwise_retriever = self.retrieve_layer(request)
                next(layerwise_retriever)  # submit the first layer's load
                self.layerwise_retrievers.append(layerwise_retriever)
            elif self.load_async:
                self.kv_recv_thread.add_request(  # type: ignore[union-attr]
                    request, )
            else:
                self._load_sync(request, token_len)

    def _load_sync(self, request: ReqMeta, token_len: int) -> None:
        """Fetch the first ``token_len`` tokens of ``request`` from the pool
        into its local blocks, skipping blocks vLLM already has cached."""
        mask_num = (
            request.load_spec.vllm_cached_tokens  # type: ignore[union-attr]
            // self.block_size * self.block_size)
        key_list = []
        addr_list = []
        size_list = []
        for start, end, key in self.token_database.process_tokens(
                token_len, request.block_hashes, mask_num):
            addr, size, _ = self.token_database.prepare_value(
                start, end, request.block_ids)
            key_list.append(key.to_string())
            addr_list.append(addr)
            size_list.append(size)
        if not key_list:
            # Nothing to fetch; also avoids a modulo by zero below.
            return
        # Rotate all three lists by the same tp_rank-dependent offset, so
        # TP ranks start from different blocks (presumably to spread the
        # concurrent requests) while keys/addrs/sizes stay aligned.
        shift = self.tp_rank % len(key_list)
        self.m_store.get(key_list[shift:] + key_list[:shift],
                         addr_list[shift:] + addr_list[:shift],
                         size_list[shift:] + size_list[:shift])

    def wait_for_layer_load(self) -> None:
        """Advance every layerwise retriever by one layer. At the last
        layer, log how many tokens each retriever requested."""
        for layerwise_retriever in self.layerwise_retrievers:
            ret_token_mask = next(layerwise_retriever)
            if self.current_layer == self.num_layers - 1:
                assert ret_token_mask is not None
                num_retrieved_tokens = ret_token_mask.sum().item()
                logger.debug("Retrieved %s tokens", num_retrieved_tokens)

    def save_kv_layer(self,
                      connector_metadata: AscendConnectorMetadata) -> None:
        """Submit the current layer of every savable request to the layerwise
        sending thread. Store generators are created when
        ``current_layer == 0``, and ``current_layer`` is advanced afterwards."""
        if self.current_layer == 0:
            self.layerwise_storers = []
            for request in connector_metadata.requests:
                if not request.can_save:  # None or False
                    continue
                self.layerwise_storers.append(self.store_layer(request))
        for layerwise_storer in self.layerwise_storers:
            next(layerwise_storer)
        self.current_layer = self.current_layer + 1

    def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
        """Queue every savable request on the (non-layerwise) sending
        thread."""
        for request in connector_metadata.requests:
            if not request.can_save:  # None or False
                continue
            self.kv_send_thread.add_request(  # type: ignore[union-attr]
                request, )

    def retrieve_layer(
        self,
        request: ReqMeta,
    ) -> Generator[Optional[torch.Tensor], None, None]:
        """Retrieve the KV cache of ``request`` one layer at a time.

        Each ``next()`` submits the load of one layer to the layerwise
        receiving thread and yields ``None``. From the second layer on, it
        first waits up to 3 s on ``self.get_event``, then clears the event,
        before submitting. After ``num_layers`` such yields, one more
        ``next()`` yields a boolean CPU mask of length
        ``request.token_len_chunk`` that marks the token ranges requested
        from the pool.

        If there is nothing to load, ``None`` is still yielded
        ``num_layers`` times, so callers can drive every retriever uniformly
        without hitting ``StopIteration``.
        """
        token_len = request.token_len_chunk
        mask_num = (
            request.load_spec.vllm_cached_tokens  # type: ignore[union-attr]
            // self.block_size * self.block_size)
        num_required_tokens = token_len - mask_num

        ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")

        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                token_len, request.block_hashes, mask_num):
            starts.append(start)
            ends.append(end)
            keys.append(key.split_layers(self.num_layers))
            ret_mask[start:end] = True

        if keys:
            # Transpose [block_num][num_layer] -> [num_layer][block_num].
            keys = [list(row) for row in zip(*keys)]
            # NOTE(review): the last layer's load is never waited on here,
            # and get_event is shared by all retrievers of a step. Confirm
            # that the receiving thread's signalling accounts for both.
            for layer_id, keys_multi_chunk in enumerate(keys):
                if layer_id > 0:
                    # Wait for the previous layer's load to be signalled.
                    is_finish = self.get_event.wait(timeout=3)
                    if not is_finish:
                        logger.info("Layerwise get failed")
                    self.get_event.clear()
                req_meta = LasyerMultiBlockReqMeta(request.req_id,
                                                   keys_multi_chunk, starts,
                                                   ends, request.block_ids,
                                                   layer_id)
                self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                yield None
        else:
            for _ in range(self.num_layers):
                yield None

        retrieved_tokens = torch.sum(ret_mask)
        logger.debug("Retrieved %s out of %s out of total %s tokens",
                     retrieved_tokens, num_required_tokens, token_len)

        yield ret_mask

    def store_layer(
        self,
        request: ReqMeta,
    ) -> Generator[None, None, None]:
        """Store the KV cache of ``request`` one layer at a time.

        Each ``next()`` submits one layer's chunks (with
        ``request.is_last_chunk``) to the layerwise sending thread and yields
        ``None``. If there is nothing to store, it still yields
        ``num_layers`` times so callers can drive every storer uniformly.
        """
        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                request.token_len_chunk, request.block_hashes):
            starts.append(start)
            ends.append(end)
            keys.append(key.split_layers(self.num_layers))  # [block][layer]

        if keys:
            keys = [list(row) for row in zip(*keys)]  # [layer][block]
            for layer_id, keys_multi_chunk in enumerate(keys):
                req_meta = LasyerMultiBlockReqMeta(request.req_id,
                                                   keys_multi_chunk, starts,
                                                   ends, request.block_ids,
                                                   layer_id,
                                                   request.is_last_chunk)
                self.kv_send_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                yield
        else:
            for _ in range(self.num_layers):
                yield

    def get_finished(self) -> tuple[set[str], set[str]]:
        """Return ``(done_sending, done_recving)`` request-id sets drained
        from the transfer threads.

        Sending results are collected only for the
        ``kv_producer``/``kv_both`` roles, and receiving results only when
        ``load_async`` is set; otherwise the corresponding set is empty.
        """
        done_sending = (
            self.kv_send_thread.
            get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.kv_role in ['kv_producer', 'kv_both'] else set())

        done_recving = (
            self.kv_recv_thread.
            get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.load_async else set())

        logger.debug(
            "Number of completed KV cache send requests: %d, receive "
            "requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
            self.tp_rank)
        return done_sending, done_recving

    def _lookup_keys(
        self,
        token_len: int,
        block_hashes: list[BlockHash],
        use_layerwise: bool,
    ) -> tuple[list[str], list[int], int]:
        """Build the pool keys for the chunks of ``token_len`` tokens.

        Returns ``(keys, starts, end)``. ``keys`` holds one key per chunk, or
        ``num_layers`` consecutive keys per chunk when ``use_layerwise`` is
        set. ``starts`` holds each chunk's first token index, and ``end`` is
        the last chunk's end (0 if there are no chunks).
        """
        keys: list[str] = []
        starts: list[int] = []
        end = 0
        for start, end, key in self.token_database.process_tokens(
                token_len, block_hashes):
            if use_layerwise:
                keys.extend(item.to_string()
                            for item in key.split_layers(self.num_layers))
            else:
                keys.append(key.to_string())
            starts.append(start)
        return keys, starts, end

    def lookup(
        self,
        token_len: int,
        block_hashes: list[BlockHash],
        use_layerwise: bool,
    ) -> int:
        """Return how many prefix tokens this rank's keys hold in the pool.

        :param token_len: number of tokens to check.
        :param block_hashes: block hashes of those tokens.
        :param use_layerwise: check the per-layer keys; a chunk counts only
            if all of its layers exist.
        :return: the start of the first missing chunk, the end of the last
            chunk if all chunks exist, or 0 if the store query fails.
        """
        try:
            keys, starts, end = self._lookup_keys(token_len, block_hashes,
                                                  use_layerwise)
            res = self.m_store.exists(keys)  # type: ignore[assignment]
            if use_layerwise:
                res = self.check_all_layers_exists(res, self.num_layers)
            for index, value in enumerate(res):  # type: ignore[arg-type]
                if value != 1:
                    return starts[index]
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
            # Existence is unknown: report no hit rather than a prefix that
            # may not actually be in the pool.
            return 0
        # All chunks were found: the whole prefix up to the last end is hit.
        return end

    @staticmethod
    def _rank_key(key: str, head: int, pp: int) -> str:
        """Rewrite a rank-0 key to the given head/TP rank and PP rank."""
        if head:
            key = key.replace("@head_or_tp_rank:0", f"@head_or_tp_rank:{head}",
                              1)
        if pp:
            key = key.replace("@pp_rank:0", f"@pp_rank:{pp}", 1)
        return key

    def lookup_scheduler(
        self,
        token_len: int,
        block_hashes: list[BlockHash],
        use_layerwise: bool,
    ) -> int:
        """Return how many prefix tokens are in the pool for ALL ranks.

        Each (head_or_tp_rank, pp_rank) pair stores its own keys, so a chunk
        counts as a hit only if every pair's key exists. The locally built
        keys are assumed to carry ``head_or_tp_rank:0`` and ``pp_rank:0``;
        they are rewritten for the other ranks.

        :param token_len: number of tokens to check.
        :param block_hashes: block hashes of those tokens.
        :param use_layerwise: check the per-layer keys; a chunk counts only
            if all of its layers exist.
        :return: the start of the first chunk missing on any rank, the end
            of the last chunk if all exist, or 0 if the store query fails.
        """
        try:
            keys, starts, end = self._lookup_keys(token_len, block_hashes,
                                                  use_layerwise)
            head_groups = min(self.tp_size, self.num_kv_head)
            # Full cross product of ranks; (0, 0) comes first and leaves the
            # keys unchanged.
            rank_pairs = [(head, pp) for pp in range(self.pp_size)
                          for head in range(head_groups)]
            multi_tp_keys = [
                self._rank_key(key, head, pp) for head, pp in rank_pairs
                for key in keys
            ]

            res = self.m_store.exists(
                multi_tp_keys)  # type: ignore[assignment]
            num_block = len(keys)
            if use_layerwise:
                res = self.check_all_layers_exists(res, self.num_layers)
                num_block = len(keys) // self.num_layers
            multi_tp_values = [
                res[i * num_block:(i + 1) * num_block]  # type: ignore[index]
                for i in range(len(rank_pairs))
            ]
            index = self.find_min_first_non_one_index(multi_tp_values)
            if index != -1:
                return starts[index]
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
            # Existence is unknown: report no hit rather than a prefix that
            # may not actually be in the pool.
            return 0
        # All chunks were found on every rank.
        return end

    def check_all_layers_exists(self, res: list[int],
                                num_layers: int) -> list[int]:
        """Collapse per-layer existence flags into per-chunk flags.

        ``res`` is read as consecutive groups of ``num_layers`` flags. Each
        group becomes 1 if every flag in it equals 1, otherwise 0. Any
        incomplete trailing group is dropped.
        """
        total_chunks = len(res) // num_layers
        return [
            1 if all(x == 1 for x in res[c * num_layers:(c + 1) * num_layers])
            else 0 for c in range(total_chunks)
        ]

    def find_min_first_non_one_index(self, arr):
        """Return the smallest column index holding a value != 1 in any row
        of ``arr``, or -1 if every value equals 1 (or ``arr`` is empty)."""
        return min(
            (idx for row in arr for idx, val in enumerate(row) if val != 1),
            default=-1)