distributed/kv_transfer/kv_connector/utils.py
@@ -0,0 +1,268 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV cache helpers for KV connector store/load operations.
"""
from typing import TYPE_CHECKING, Literal

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase

logger = init_logger(__name__)


class model_aware_kv_ops_helper:
    def __init__(self, config: VllmConfig):
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def get_model_args(self, model_executable: torch.nn.Module):
        model_config = model_executable.model.config
        self.model_executable = model_executable
        num_heads = int(model_config.num_key_value_heads / self.tp_size)
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads

        # Deepseek's MLA (Multi-head Latent Attention) uses two different
        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
        # kv_lora_rank + qk_rope_head_dim].
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/v1/attention/backends/mla/common.py.
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim
            num_heads = 1
        elif self.is_deepseek_mla and not self.use_mla_opt:
            head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim
        else:
            head_size = getattr(model_config, "head_dim", None)
            if head_size is None:
                head_size = int(hidden_size // num_attention_heads)

        return num_heads, head_size
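        # NOTE (illustrative, not part of this change): with DeepSeek-V3-like
        # dimensions of kv_lora_rank=512 and qk_rope_head_dim=64, the MLA-enabled
        # branch above returns num_heads=1 and head_size=512 + 64 = 576, i.e. a
        # single latent "head" of width 576 per token in the kv_cache.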
    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        if self.is_deepseek_mla and self.use_mla_opt:
            key_cache = kv_cache.reshape(-1, num_heads, head_size)
            value_cache = kv_cache.reshape(-1, num_heads, head_size)
        else:
            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
        return key_cache, value_cache

    def put_kv_to_cache(
        self,
        model_executable: torch.nn.Module,
        keys,
        values,
        layer,
        kv_cache,
        slot_mapping,
        start_pos,
        end_pos,
    ):
        model_config = model_executable.model.config

        if self.is_deepseek_mla and self.use_mla_opt:
            layer.self_attn.attn = layer.self_attn.mla_attn
            k_c_normed_k_pe = keys.squeeze(1)
            k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank]
            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :]
            ops.concat_and_cache_mla(
                k_c_normed.to(kv_cache.device),
                k_pe.to(kv_cache.device),
                kv_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
            )
        else:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            ops.reshape_and_cache_flash(
                keys.to(key_cache.device),
                values.to(value_cache.device),
                key_cache,
                value_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
                layer.self_attn.attn._v_scale,
            )
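# Illustrative sketch (not part of this change): how a connector worker might use
# the helper above to pull one layer's KV entries out of the paged cache.
# `model_executable`, `kv_cache`, `slot_mapping`, `start_pos`, and `end_pos` are
# assumed to be supplied by the model runner.
def _example_extract_layer_kv(
    config: VllmConfig,
    model_executable: torch.nn.Module,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    start_pos: int,
    end_pos: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    helper = model_aware_kv_ops_helper(config)
    num_heads, head_size = helper.get_model_args(model_executable)
    key_cache, value_cache = helper.get_kv_from_cache(kv_cache, num_heads, head_size)
    # slot_mapping holds flat slot indices into the (num_slots, num_heads, head_size)
    # view returned above; select only the slots for tokens in [start_pos, end_pos).
    slots = slot_mapping[start_pos:end_pos]
    return key_cache[slots], value_cache[slots]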
def get_kv_connector_cache_layout():
    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if kv_config is not None:
        connector_cls = KVConnectorFactory.get_connector_class(kv_config)
        required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config)
        if required_kvcache_layout is not None:
            return required_kvcache_layout
    logger.info_once(
        "Connectors do not specify a kv cache layout, defaulting to NHD."
    )
    return "NHD"
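# For reference (an assumption based on the common NHD/HND naming convention):
# "NHD" orders each cache block as (block_size, num_kv_heads, head_size), while
# "HND" orders it as (num_kv_heads, block_size, head_size), keeping each head's
# data contiguous, which is presumably why NIXL-based transfer prefers it.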
class KVOutputAggregator:
    """Utility class to aggregate the outputs of all workers into a single
    output corresponding to rank 0, for the scheduler."""

    def __init__(self, expected_finished_count: int):
        # Complete transfer tracker. Used to track finished requests
        # [req_id -> n_remaining_workers]
        self._recv_remaining_count = dict[str, int]()
        self._send_remaining_count = dict[str, int]()
        self._expected_finished_count = expected_finished_count

    @classmethod
    def from_connector(cls, connector: "KVConnectorBase", world_size: int):
        return cls(connector.get_finished_count() or world_size)
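    # Example of the counting semantics (illustrative, not part of this change):
    # with TP=4 and the default expected_finished_count of 4, request "req-0" only
    # enters the aggregated finished_sending set on the step in which the fourth
    # worker reports it; earlier reports just decrement its remaining-worker count.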
    def aggregate(
        self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
    ) -> ModelRunnerOutput | None:
        if not outputs[output_rank]:
            return None

        # Aggregate kv_connector_output from all workers

        def update_finished_set(
            req_ids: set[str] | None,
            remaining_count_dict: dict[str, int],
            finished_set: set[str],
        ) -> None:
            for req_id in req_ids or ():
                remaining_count = remaining_count_dict.get(
                    req_id, self._expected_finished_count
                )
                remaining_count_dict[req_id] = remaining_count - 1
                if remaining_count_dict[req_id] == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]

        finished_sending = set[str]()
        finished_recving = set[str]()
        aggregated_kv_connector_stats = None
        invalid_block_ids = set[int]()
        for model_runner_output in outputs:
            assert model_runner_output is not None
            kv_output = model_runner_output.kv_connector_output
            if not kv_output:
                continue
            # Allow the worker to dynamically update the expected number of
            # finished sending/recving for new requests.
            if (
                kv_output.expected_finished_count > 0
                and kv_output.expected_finished_count != self._expected_finished_count
            ):
                logger.debug(
                    "Expected finished requests updated from %d to %d",
                    self._expected_finished_count,
                    kv_output.expected_finished_count,
                )
                self._expected_finished_count = kv_output.expected_finished_count

            update_finished_set(
                kv_output.finished_sending, self._send_remaining_count, finished_sending
            )
            update_finished_set(
                kv_output.finished_recving, self._recv_remaining_count, finished_recving
            )
            # Aggregate kv_connector_stats from all workers.
            if kv_connector_stats := kv_output.kv_connector_stats:
                if aggregated_kv_connector_stats is None:
                    # Use the first worker's kv_connector_stats as accumulator.
                    aggregated_kv_connector_stats = kv_connector_stats
                else:
                    assert isinstance(
                        aggregated_kv_connector_stats, type(kv_connector_stats)
                    )
                    aggregated_kv_connector_stats = (
                        aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                    )
            invalid_block_ids |= kv_output.invalid_block_ids

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        assert output is not None
        output.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending or None,
            finished_recving=finished_recving or None,
            kv_connector_stats=aggregated_kv_connector_stats or None,
            invalid_block_ids=invalid_block_ids,
            expected_finished_count=self._expected_finished_count,
        )

        return output
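# Illustrative sketch (not part of this change): how a driver might use
# KVOutputAggregator across tensor-parallel workers. A single aggregator is kept
# for the whole run because its per-request counters span multiple steps;
# `steps_of_worker_outputs` is assumed to hold, for each step, the per-rank list
# of ModelRunnerOutput gathered from the workers.
def _example_aggregate_steps(
    connector: "KVConnectorBase",
    world_size: int,
    steps_of_worker_outputs: list[list[ModelRunnerOutput | None]],
) -> list[ModelRunnerOutput | None]:
    aggregator = KVOutputAggregator.from_connector(connector, world_size)
    # Merge finished_sending/finished_recving/stats from all ranks into the
    # output of rank 0, which is what the scheduler consumes.
    return [
        aggregator.aggregate(step_outputs, output_rank=0)
        for step_outputs in steps_of_worker_outputs
    ]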
def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: torch.device | str,
    dst_device: torch.device | str,
) -> tuple[torch.Tensor, torch.Tensor]:
    src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64)
    dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64)
    return src_indices, dst_indices


def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy kv blocks between different buffers."""
    if (
        not src_kv_caches
        or not dst_kv_caches
        or not src_block_ids
        or not dst_block_ids
        or len(src_block_ids) != len(dst_block_ids)
    ):
        return

    src_device = next(iter(src_kv_caches.values())).device
    dst_device = next(iter(dst_kv_caches.values())).device

    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=src_device,
        dst_device=dst_device,
    )

    from vllm.platforms import current_platform

    if direction == "h2d":
        copy_fn = current_platform.insert_blocks_to_device
    else:
        copy_fn = current_platform.swap_out_blocks_to_host
    for layer_name in src_kv_caches:
        src_tensor = src_kv_caches[layer_name]
        dst_tensor = dst_kv_caches[layer_name]
        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
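# Illustrative usage (not part of this change): swapping a handful of device
# blocks out to a host-side buffer. `gpu_caches` and `host_caches` are assumed
# to be per-layer cache tensors with identical block shapes.
def _example_offload_blocks(
    gpu_caches: dict[str, torch.Tensor],
    host_caches: dict[str, torch.Tensor],
) -> None:
    # Copy device blocks 0..3 into host blocks 4..7 for every layer.
    copy_kv_blocks(
        src_kv_caches=gpu_caches,
        dst_kv_caches=host_caches,
        src_block_ids=[0, 1, 2, 3],
        dst_block_ids=[4, 5, 6, 7],
        direction="d2h",
    )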