# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV cache helper for store.
"""

from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast

import torch

from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase

logger = init_logger(__name__)

EngineId = str


def get_kv_connector_cache_layout():
    """Return the KV cache layout required by the configured KV connector.

    If the current vLLM config has a KV transfer config, the connector class
    is asked for its required layout via ``get_required_kvcache_layout``.
    Falls back to ``"NHD"`` when no connector is configured or the connector
    does not specify a layout.
    """
    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if kv_config is not None:
        connector_cls = KVConnectorFactory.get_connector_class(kv_config)
        required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config)
        if required_kvcache_layout is not None:
            return required_kvcache_layout
        logger.info_once(
            "Connectors do not specify a kv cache layout, defaulting to NHD."
        )
    return "NHD"


class KVOutputAggregator:
    """Utility class to aggregate the output of all workers into a single
    output corresponding to Rank 0 for scheduler.

    A request id is reported as finished sending/receiving only after
    ``expected_finished_count`` worker reports for it have been seen.
    """

    def __init__(self, expected_finished_count: int):
        # Complete transfer tracker. Used to track finished requests
        # [req_id -> n_remaining_workers]
        self._recv_remaining_count = dict[str, int]()
        self._send_remaining_count = dict[str, int]()
        # Number of worker reports needed before a request counts as finished.
        self._expected_finished_count = expected_finished_count

    @classmethod
    def from_connector(cls, connector: "KVConnectorBase", world_size: int):
        """Create an aggregator from the connector's finished count.

        Falls back to ``world_size`` when ``connector.get_finished_count()``
        returns a falsy value.
        """
        return cls(connector.get_finished_count() or world_size)

    def aggregate(
        self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
    ) -> ModelRunnerOutput | None:
        """Merge the KV connector outputs of all workers into one output.

        Args:
            outputs: Per-worker model runner outputs. Once
                ``outputs[output_rank]`` is truthy, every entry is asserted
                to be non-None.
            output_rank: Index of the worker whose output object is reused
                (mutated in place) as the aggregated result.

        Returns:
            ``outputs[output_rank]`` with its ``kv_connector_output`` replaced
            by the aggregated output, or None if that entry is falsy.
        """
        if not outputs[output_rank]:
            return None

        # Aggregate kv_connector_output from all workers

        def update_finished_set(
            req_ids: set[str] | None,
            remaining_count_dict: dict[str, int],
            finished_set: set[str],
        ) -> None:
            # Decrement the outstanding-report count of each request id; once
            # it reaches zero, the request is finished and stops being tracked.
            for req_id in req_ids or ():
                remaining_count = remaining_count_dict.get(
                    req_id, self._expected_finished_count
                )
                remaining_count_dict[req_id] = remaining_count - 1
                if remaining_count_dict[req_id] == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]

        finished_sending = set[str]()
        finished_recving = set[str]()
        aggregated_kv_connector_stats = None
        combined_kv_cache_events = None
        invalid_block_ids = set[int]()
        for model_runner_output in outputs:
            assert model_runner_output is not None
            kv_output = model_runner_output.kv_connector_output
            if not kv_output:
                continue
            # Allow the worker to dynamically update the expected number of
            # finished sending/recving for new requests.
            if (
                kv_output.expected_finished_count > 0
                and kv_output.expected_finished_count != self._expected_finished_count
            ):
                logger.debug(
                    "Expected finished requests updated from %d to %d",
                    self._expected_finished_count,
                    kv_output.expected_finished_count,
                )
                self._expected_finished_count = kv_output.expected_finished_count

            update_finished_set(
                kv_output.finished_sending, self._send_remaining_count, finished_sending
            )
            update_finished_set(
                kv_output.finished_recving, self._recv_remaining_count, finished_recving
            )

            # Aggregate kv_connector_stats from all workers.
            kv_connector_stats = kv_output.kv_connector_stats
            if aggregated_kv_connector_stats is None:
                # Use the first worker's (non-None) kv_connector_stats as the
                # accumulator.
                aggregated_kv_connector_stats = kv_connector_stats
            elif kv_connector_stats:
                assert isinstance(
                    aggregated_kv_connector_stats, type(kv_connector_stats)
                )
                aggregated_kv_connector_stats = (
                    aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                )

            # Combine kv_cache_events from all workers.
            kv_cache_events = kv_output.kv_cache_events
            if combined_kv_cache_events is None:
                # Use the first worker's kv_cache events as start event list.
                combined_kv_cache_events = kv_cache_events
            elif kv_cache_events:
                assert isinstance(
                    combined_kv_cache_events,
                    type(kv_cache_events),
                )
                worker_kv_cache_events = kv_cache_events.get_all_events()
                combined_kv_cache_events.add_events(worker_kv_cache_events)
                combined_kv_cache_events.increment_workers(1)

            invalid_block_ids |= kv_output.invalid_block_ids

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        assert output is not None
        output.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending or None,
            finished_recving=finished_recving or None,
            kv_connector_stats=aggregated_kv_connector_stats or None,
            kv_cache_events=combined_kv_cache_events or None,
            invalid_block_ids=invalid_block_ids,
            expected_finished_count=self._expected_finished_count,
        )

        return output


def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: torch.device | str,
    dst_device: torch.device | str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build int64 index tensors for the source and destination block ids,
    each placed on its respective device."""
    src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64)
    dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64)
    return src_indices, dst_indices


def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy kv blocks between different buffers.

    Does nothing if either cache dict or either block-id list is empty, or if
    the two block-id lists differ in length. For every layer name in
    ``src_kv_caches``, ``src_block_ids[i]`` is copied to ``dst_block_ids[i]``
    using ``current_platform.insert_blocks_to_device`` for ``"h2d"`` and
    ``current_platform.swap_out_blocks_to_host`` otherwise.
    """
    if (
        not src_kv_caches
        or not dst_kv_caches
        or not src_block_ids
        or not dst_block_ids
        or len(src_block_ids) != len(dst_block_ids)
    ):
        return

    # The device of the first tensor in each dict is used for its indices.
    src_device = next(iter(src_kv_caches.values())).device
    dst_device = next(iter(dst_kv_caches.values())).device

    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=src_device,
        dst_device=dst_device,
    )

    if direction == "h2d":
        copy_fn = current_platform.insert_blocks_to_device
    else:
        copy_fn = current_platform.swap_out_blocks_to_host
    for layer_name in src_kv_caches:
        src_tensor = src_kv_caches[layer_name]
        dst_tensor = dst_kv_caches[layer_name]
        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio):
    """
    Transforms the layout of received KV cache blocks to the local block_size.
    (Only works for local blocksize > remote blocksize)

    example:
    local blocksize = 16 tokens, remote blocksize = 4 tokens
    local block[0] = remote block[0, 1, 2, 3]
    remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
    local  is |h0-b0..................|h1-b0..................|...
    The permutation is done by:
    1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
    2. permute => (H, nblocks, remoteN, D)
    3. flatten => (H, localN, D)

    The blocks at ``indices`` (along dim 0) are rewritten in place.
    """
    blocks_to_update = cache.index_select(0, indices)
    # use physical order
    blocks_to_update = blocks_to_update.permute(0, 2, 1, 3)
    n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
    remote_block_size = block_size // block_size_ratio
    n_blocks = block_size_ratio

    permuted_blocks = (
        blocks_to_update.reshape(-1, n_blocks, n_kv_heads, remote_block_size, head_size)
        .permute(0, 2, 1, 3, 4)
        .flatten(2, 3)
    )
    # Back from physical to the cache's logical dim order before writing.
    permuted_blocks = permuted_blocks.permute(0, 2, 1, 3)
    cache.index_copy_(0, indices, permuted_blocks)


def kv_postprocess_layout_on_receive(cache, indices):
    """Transforms the layout of received KV cache blocks to the local format.

    This method corrects layout mismatches from direct memory copies by
    permuting the tensor dimensions.

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Implementation:
    - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
    - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
    - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
    """
    # Gather once; the gathered copy is reused both for its shape and for the
    # reshape below (cache is not modified in between).
    blocks_to_update = cache.index_select(0, indices)
    target_shape = list(blocks_to_update.shape)
    target_shape[0] = -1
    inv_order = [0, 2, 1, 3]
    src_shape = tuple(target_shape[i] for i in inv_order)
    permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
    cache.index_copy_(0, indices, permuted_blocks)


def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
    """
    Transforms the layout of received KV cache to the local block_size and HND.
    (Only works for local blocksize > remote blocksize)

    prefill is HND, smaller block_size
    decode(local) is NHD, larger block_size

    The blocks at ``indices`` (along dim 0) are rewritten in place.
    """
    blocks_to_update = cache.index_select(0, indices)

    block_size, n_kv_heads, head_size = blocks_to_update.shape[1:]
    remote_block_size = block_size // block_size_ratio
    n_blocks = block_size_ratio

    permuted_blocks = (
        blocks_to_update.reshape(-1, n_blocks, n_kv_heads, remote_block_size, head_size)
        .permute(0, 1, 3, 2, 4)
        .flatten(1, 2)
    )
    cache.index_copy_(0, indices, permuted_blocks)


def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """
    Yields:
        (req_id, new_block_id_groups, preempted)

    New requests are yielded first (always with ``preempted=False``), then
    cached requests, whose flag is True when the id is in
    ``resumed_req_ids``.
    """
    # new requests
    for req_data in scheduler_output.scheduled_new_reqs:
        yield req_data.req_id, req_data.block_ids, False

    # cached requests
    cached_reqs = scheduler_output.scheduled_cached_reqs
    yield from zip(
        cached_reqs.req_ids,
        cached_reqs.new_block_ids,
        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
    )


@dataclass
class TpKVTopology:
    """
    Helper class for tensor parallel and KV topology information for
    mapping between local and remote TP workers.
    """

    tp_rank: int
    remote_tp_size: dict[EngineId, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    engine_id: EngineId
    remote_block_size: dict[EngineId, int]
    tensor_shape: torch.Size | None = None
    # Pipeline-parallel rank/size of the local worker and PP size of each
    # remote engine, read by pp_ratio() and get_target_remote_ranks*().
    # The defaults describe a deployment without pipeline parallelism.
    pp_rank: int = 0
    pp_size: int = 1
    remote_pp_size: dict[EngineId, int] = field(default_factory=dict)

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        _MOCK_BLOCK_SIZE = 16
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
        )
        logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
        # we just mock num_blocks to 1 for the dimension check below.
        self._is_kv_layout_blocks_first = (
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
        )

        # A registered tensor with one more dim than the backend's per-layer
        # shape is treated as a cross-layer (layers-stacked) KV cache.
        self._cross_layers_blocks = False
        if self.tensor_shape is not None:
            self._cross_layers_blocks = (
                len(self.tensor_shape) == len(kv_cache_shape) + 1
            )

        if self._cross_layers_blocks:
            logger.debug("Using cross-layer KV cache")
            # prepend layers dimension
            _MOCK_NUM_LAYERS = 80
            kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
            try:
                kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
                    include_num_layers_dimension=self._cross_layers_blocks
                )
            except (AttributeError, NotImplementedError):
                # Backend exposes no stride order: assume identity order.
                kv_cache_stride_order = tuple(range(len(self.tensor_shape)))

            # In case of cross layers permute kv_cache_shape according to
            # stride_order to retrieve physical position of block_size
            kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

        # In the default non-cross layers layout the block_size position
        # is logical while in the cross layers case it is the physical
        # position. This matches the shape of the actual kv cache tensors
        # passed at register_kv_caches()/register_cross_layers_kv_cache()
        # (tuple.index raises ValueError if the mock block size is absent.)
        block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)
        # Stored as a negative (from-the-end) dimension index.
        self._block_size_position = -(len(kv_cache_shape) - block_size_position)

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        # Whether to register regions for K and V separately (when present).
        return not (
            self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
        )

    @property
    def tp_size(self) -> int:
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        return self.remote_block_size[self.engine_id]

    @property
    def cross_layers_blocks(self) -> bool:
        return self._cross_layers_blocks

    @property
    def block_size_position(self) -> int:
        return self._block_size_position

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers-per-remote TP
        workers. Local workers will read from the same remote TP worker in
        groups of size `tp_ratio`. If remote tp_size > local tp_size, the
        ratio is flipped (remote_size/local_size), i.e. it is the number of
        remote workers each local worker reads from.

        NOTE(review): the returned value is positive in both cases, so callers
        must compare TP sizes (as get_target_remote_ranks does) to tell the
        two cases apart. A previous version of this docstring claimed the
        flipped ratio is negative; confirm no external caller tests the sign.
        """
        if self.tp_size >= remote_tp_size:
            assert self.tp_size % remote_tp_size == 0, (
                f"Local tensor parallel size {self.tp_size} is not divisible "
                f"by remote tensor parallel size {remote_tp_size}."
            )
            return self.tp_size // remote_tp_size

        assert remote_tp_size % self.tp_size == 0, (
            f"Remote tensor parallel size {remote_tp_size} is not divisible "
            f"by local tensor parallel size {self.tp_size}."
        )
        # P TP > D TP case: the flipped (positive) ratio.
        return remote_tp_size // self.tp_size

    def pp_ratio(
        self,
        remote_pp_size: int,
    ) -> int:
        """
        Calculate the pipeline parallel ratio between local and remote PP.

        Returns ``pp_size // remote_pp_size`` when the local size is a
        multiple of the remote one, otherwise ``remote_pp_size // pp_size``.
        Always positive.
        """
        assert self.pp_size % remote_pp_size == 0 or remote_pp_size % self.pp_size == 0, (
            f"Local pipeline parallel size {self.pp_size} is not divisible "
            f"by remote pipeline parallel size {remote_pp_size} or vice versa."
        )
        if self.pp_size % remote_pp_size == 0:
            return self.pp_size // remote_pp_size
        return remote_pp_size // self.pp_size

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> int:
        """
        Calculate the ratio between the local and remote block sizes
        (local // remote). The local block size must be a multiple of the
        remote one.
        """
        assert self.block_size % remote_block_size == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size} or vice versa."
        )
        return self.block_size // remote_block_size

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> int:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.tp_ratio(remote_tp_size)

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> int:
        remote_block_size = self.remote_block_size[remote_engine_id]
        return self.block_size_ratio(remote_block_size)

    def is_kv_replicated(self, engine_id: EngineId) -> bool:
        """
        Whether the KV cache is replicated across TP workers due to the
        number of TP workers being greater than the number of KV heads.
        """
        tp_size = self.remote_tp_size[engine_id]
        return tp_size // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
        # MLA is always replicated as the hidden dim can't be split.
        return self.is_mla or self.is_kv_replicated(remote_engine_id)

    def get_target_remote_ranks(
        self, remote_tp_size: int, remote_pp_size: int
    ) -> list[tuple[int, int, int]]:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from. When remote tp_size > local tp_size, we
        read from multiple remote ranks. The same applies to PP.

        Returns:
            A list of ``(remote_rank, remote_pp_rank, remote_tp_rank)``
            tuples, where ``remote_rank = remote_pp_rank * remote_tp_size +
            remote_tp_rank``, ordered by PP rank, then TP rank.
        """
        # abs() keeps this correct regardless of the ratio helpers' sign
        # convention; the direction is decided by comparing sizes below.
        tp_ratio = abs(self.tp_ratio(remote_tp_size))
        pp_ratio = abs(self.pp_ratio(remote_pp_size))

        if self.pp_size < remote_pp_size:
            # Each local PP rank reads from pp_ratio consecutive remote ranks.
            target_pp_ranks = [self.pp_rank * pp_ratio + i for i in range(pp_ratio)]
        else:
            target_pp_ranks = [self.pp_rank // pp_ratio]

        if self.tp_size < remote_tp_size:
            # Each local TP rank reads from tp_ratio consecutive remote ranks.
            target_tp_ranks = [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
        else:
            target_tp_ranks = [self.tp_rank // tp_ratio]

        return [
            (pp_rank * remote_tp_size + tp_rank, pp_rank, tp_rank)
            for pp_rank in target_pp_ranks
            for tp_rank in target_tp_ranks
        ]

    def get_target_remote_ranks_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> list[tuple[int, int, int]]:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        # Engines without registered PP info are treated as PP size 1.
        remote_pp_size = self.remote_pp_size.get(remote_engine_id, 1)
        return self.get_target_remote_ranks(remote_tp_size, remote_pp_size)


def get_current_attn_backend(vllm_config: VllmConfig):
    """Return the attention backend of the first attention layer in the config,
    or, if the config has no attention layers, the backend selected by
    ``get_attn_backend`` from the model/cache config."""
    layer_type = cast(type[Any], AttentionLayerBase)
    layers = get_layers_from_vllm_config(vllm_config, layer_type, None)
    if layers:
        backend = next(iter(layers.values())).get_attn_backend()
    else:
        # Fallback for tests, when static_forward_context is empty.
        logger.debug(
            "No layers found in the vLLM config. Falling back to default attention backend."
        )
        from vllm.v1.attention.selector import get_attn_backend

        backend = get_attn_backend(
            head_size=vllm_config.model_config.get_head_size(),
            dtype=vllm_config.model_config.dtype,
            kv_cache_dtype=vllm_config.cache_config.cache_dtype,
            block_size=vllm_config.cache_config.block_size,
            use_mla=vllm_config.model_config.use_mla,
        )
    return backend