[Model] Support DeepSeek-V4
This commit is contained in:
20
vllm_mlu/distributed/kv_transfer/kv_connector/factory.py
Normal file
20
vllm_mlu/distributed/kv_transfer/kv_connector/factory.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
|
||||
|
||||
MLUKVConnectors: dict[str, tuple[str, str]] = {
|
||||
"MLUSharedStorageConnector": (
|
||||
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
|
||||
"SharedStorageConnector"
|
||||
),
|
||||
"MLUNixlConnector": (
|
||||
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.nixl_connector",
|
||||
"MLUNixlConnector"
|
||||
),
|
||||
}
|
||||
|
||||
for name, (module_path, class_name) in MLUKVConnectors.items():
|
||||
if name not in KVConnectorFactory._registry:
|
||||
KVConnectorFactory.register_connector(name, module_path, class_name)
|
||||
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
class LMCacheConnectorV1_MluHijack(LMCacheConnectorV1):
|
||||
|
||||
def response_remote_alloc_once(self) -> None:
|
||||
self._lmcache_engine.response_remote_alloc_once()
|
||||
|
||||
def request_remote_memory_send(self) -> None:
|
||||
self._lmcache_engine.request_remote_memory_send()
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(LMCacheConnectorV1,
|
||||
"response_remote_alloc_once",
|
||||
LMCacheConnectorV1_MluHijack.response_remote_alloc_once)
|
||||
MluHijackObject.apply_hijack(LMCacheConnectorV1,
|
||||
"request_remote_memory_send",
|
||||
LMCacheConnectorV1_MluHijack.request_remote_memory_send)
|
||||
@@ -0,0 +1,346 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
EngineId, NixlConnectorWorker, NixlAgentMetadata, NixlConnectorScheduler, NixlConnector)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
||||
try:
|
||||
from nixl._api import nixl_agent as NixlWrapper
|
||||
logger.info("NIXL is available")
|
||||
except ImportError:
|
||||
logger.warning("NIXL is not available")
|
||||
NixlWrapper = None
|
||||
|
||||
|
||||
class MLUNixlConnector(NixlConnector):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super(NixlConnector, self).__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
assert vllm_config.kv_transfer_config.engine_id is not None
|
||||
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler : MLUNixlConnectorScheduler | None = (
|
||||
MLUNixlConnectorScheduler(vllm_config, self.engine_id)
|
||||
)
|
||||
self.connector_worker: MLUNixlConnectorWorker | None = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MLUNixlConnectorWorker(vllm_config, self.engine_id)
|
||||
|
||||
|
||||
class MLUNixlConnectorScheduler(NixlConnectorScheduler):
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: kv transfer info
|
||||
'''
|
||||
if request.kv_transfer_params.get("do_remote_prefill", False):
|
||||
logger.info(f"NIXLConnector update_state_after_alloc: request_id={request.request_id}, "
|
||||
f"num_prompt_tokens={request.num_prompt_tokens}, "
|
||||
f"num_external_tokens={num_external_tokens}, "
|
||||
f"kv_transfer_params={request.kv_transfer_params}")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"NIXLConnector update_state_after_alloc: "
|
||||
"num_external_tokens=%s, kv_transfer_params=%s",
|
||||
num_external_tokens,
|
||||
params,
|
||||
)
|
||||
|
||||
if not params:
|
||||
return
|
||||
|
||||
if params.get("do_remote_decode"):
|
||||
self._reqs_in_batch.add(request.request_id)
|
||||
if self.use_host_buffer and params.get("do_remote_decode"):
|
||||
# NOTE: when accelerator is not directly supported by Nixl,
|
||||
# prefilled blocks need to be saved to host memory before transfer.
|
||||
|
||||
# save all blocks
|
||||
block_ids = blocks.get_block_ids()[0]
|
||||
# TODO: skip the blocks that are already in the host xfer buffer.
|
||||
# Currently, the host xfer buffer block is 1-to-1 mapped to device
|
||||
# kv blocks, so host blocks won't be flushed as long as its device
|
||||
# block is not overwritten; and it will be safe to skip saving them
|
||||
# to host xfer buffer.
|
||||
if block_ids:
|
||||
self._reqs_need_save[request.request_id] = (request, block_ids)
|
||||
elif params.get("do_remote_prefill"):
|
||||
if params.get("remote_block_ids"):
|
||||
if all(
|
||||
p in params
|
||||
for p in ("remote_engine_id", "remote_host", "remote_port")
|
||||
):
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
local_block_ids = (
|
||||
blocks.get_unhashed_block_ids()
|
||||
if num_external_tokens > 0
|
||||
else []
|
||||
)
|
||||
# Get unhashed blocks to pull from remote.
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request,
|
||||
local_block_ids,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
"request will not utilize KVTransfer",
|
||||
params,
|
||||
)
|
||||
else:
|
||||
assert num_external_tokens == 0
|
||||
# Only trigger 1 KV transfer per request.
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
|
||||
class MLUNixlConnectorWorker(NixlConnectorWorker):
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data in nixl."""
|
||||
_, first_kv_cache = next(iter(kv_caches.items()))
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: not support kv8
|
||||
'''
|
||||
if not isinstance(first_kv_cache, torch.Tensor):
|
||||
kv_caches = {key: value[0] for key, value in kv_caches.items()}
|
||||
_, first_kv_cache = next(iter(kv_caches.items()))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
|
||||
# KV memory layout is HND, as opposed to the default NHD. Note that it
|
||||
# will only affects the strides. For MLA instead, we make require no
|
||||
# such thing and resort to the standard layout.
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: support mla
|
||||
'''
|
||||
use_mla = first_kv_cache.shape[0] == 1
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
assert use_mla == self.use_mla
|
||||
|
||||
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
|
||||
# once it goes live, as a single kv layout is expected for xfers.
|
||||
if use_mla:
|
||||
# MLA case.
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: support mla
|
||||
'''
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
block_rank = 2 # [block_size, latent_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
block_size, kv_latent_dim = block_shape
|
||||
self.slot_size_bytes = kv_elem_size * kv_latent_dim
|
||||
else:
|
||||
# [2 (k and v), num_blocks, ...]
|
||||
if self._use_flashinfer:
|
||||
# FlashInfer swaps 2<->num_blocks dimensions.
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 4 # [2, block_size, kv_heads, head_dim]
|
||||
else:
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: MLU kv_cache layout is [2 (k and v), num_blocks, kv_heads, block_size, head_dim]
|
||||
'''
|
||||
n_kv_heads, block_size, head_dim = block_shape[-3:]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# head size in bytes.
|
||||
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
|
||||
assert block_size == self.block_size
|
||||
# TODO(tms): self.block_len needs to be per-layer for sliding window,
|
||||
# hybrid attn, etc
|
||||
# block size in bytes
|
||||
self.block_len = kv_elem_size * math.prod(block_shape)
|
||||
logger.info(
|
||||
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
|
||||
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
|
||||
self.num_blocks, block_shape, first_kv_cache.shape)
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
caches_data = []
|
||||
|
||||
# Note(tms): I modified this from the original region setup code.
|
||||
# K and V are now in different regions. Advantage is that we can
|
||||
# elegantly support MLA and any cases where the K and V tensors
|
||||
# are non-contiguous (it's not locally guaranteed that they will be)
|
||||
# Disadvantage is that the encoded NixlAgentMetadata is now larger
|
||||
# (roughly 8KB vs 5KB).
|
||||
# Conversely for FlashInfer, K and V are transferred in the same tensor
|
||||
# to better exploit the memory layout (ie num_blocks is the first dim).
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
|
||||
else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len
|
||||
caches_data.append(
|
||||
(base_addr, region_len, cache.device.index, ""))
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||
self.num_regions = len(caches_data)
|
||||
self.num_layers = len(self.kv_caches.keys())
|
||||
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
||||
from transformers import Llama4TextConfig
|
||||
assert isinstance(self.vllm_config.model_config.hf_text_config,
|
||||
Llama4TextConfig)
|
||||
llama4_config = self.vllm_config.model_config.hf_text_config
|
||||
no_rope_layers = llama4_config.no_rope_layers
|
||||
chunk_size = llama4_config.attention_chunk_size
|
||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||
for layer_idx in range(self.num_layers):
|
||||
# no_rope_layers[layer_idx] == 0 means NoPE (global)
|
||||
# Any other value means RoPE (local chunked)
|
||||
is_local_attention = no_rope_layers[layer_idx] != 0
|
||||
block_window = chunk_block_size if is_local_attention else None
|
||||
self.block_window_per_layer.append(block_window)
|
||||
logger.debug("Llama 4 block window per layer mapping: %s",
|
||||
self.block_window_per_layer)
|
||||
assert len(self.block_window_per_layer) == self.num_layers
|
||||
|
||||
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
|
||||
logger.debug("Registering descs: %s", caches_data)
|
||||
self.nixl_wrapper.register_memory(descs)
|
||||
logger.debug("Done registering descs")
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
# Register local/src descr for NIXL xfer.
|
||||
blocks_data = []
|
||||
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
||||
# NOTE With heter-TP, more blocks are prepared than what are
|
||||
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
|
||||
# could create fewer, but then _get_block_descs_ids needs to
|
||||
# select agent_meta.num_blocks instead of self.num_blocks for
|
||||
# local descr, and that makes handling regular flow less clean.
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
addr = base_addr + block_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, self.block_len, self.tp_rank))
|
||||
logger.debug("Created %s blocks for src engine %s and rank %s",
|
||||
len(blocks_data), self.engine_id, self.tp_rank)
|
||||
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
||||
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||
"NIXL_INIT_AGENT", descs)
|
||||
|
||||
# After KV Caches registered, listen for new connections.
|
||||
metadata = NixlAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
num_blocks=self.num_blocks,
|
||||
tp_size=self.world_size,
|
||||
block_len=self.block_len,
|
||||
attn_backend_name=self.backend_name)
|
||||
ready_event = threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
|
||||
daemon=True,
|
||||
name="nixl_handshake_listener")
|
||||
self._nixl_handshake_listener_t.start()
|
||||
ready_event.wait()
|
||||
@@ -0,0 +1,450 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
from vllm_mlu.v1.attention.backends.flash_mla import MLAFlashAttentionCommonMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Is store or load
|
||||
is_store: bool
|
||||
mm_hashes: list[str]
|
||||
|
||||
@staticmethod
|
||||
def make_meta(
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
mm_hashes: list[str],
|
||||
) -> "ReqMeta":
|
||||
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
|
||||
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = (
|
||||
block_offsets.reshape((1, block_size))
|
||||
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
)
|
||||
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
|
||||
return ReqMeta(
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
is_store=is_store,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedStorageConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta] = field(default_factory=list)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
mm_hashes: list[str],
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
|
||||
)
|
||||
|
||||
|
||||
class SharedStorageConnector(KVConnectorBase_V1):
|
||||
# NOTE: This is Simple debug implementation of the KV connector.
|
||||
# It save / load the KV cache to / from the disk.
|
||||
# It does extra work which will overwrite the existing prefix-cache in GPU
|
||||
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._requests_need_load: dict[str, Request] = {}
|
||||
self._storage_path = self._kv_transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp")
|
||||
logger.info(self._kv_transfer_config)
|
||||
logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||
paged KV buffer.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
def inject_kv_into_layer(
|
||||
dst_kv_cache_layer: torch.Tensor,
|
||||
src_kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""Inject the KV cache into the layer.
|
||||
|
||||
Args:
|
||||
dst_kv_cache_layer (torch.Tensor): the destination KV cache
|
||||
layer. In shape [2, num_pages, page_size, xxx] if not
|
||||
using MLA, [num_pages, page_size, xxx] otherwise.
|
||||
src_kv_cache (torch.Tensor): the source KV cache. In shape
|
||||
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
|
||||
otherwise.
|
||||
slot_mapping (torch.Tensor): the slot mapping. In shape
|
||||
[num_tokens].
|
||||
"""
|
||||
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
|
||||
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
|
||||
num_pages = dst_kv_cache_layer_shape[0]
|
||||
page_size = dst_kv_cache_layer_shape[1]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||
else:
|
||||
num_pages = dst_kv_cache_layer_shape[1]
|
||||
page_size = dst_kv_cache_layer_shape[2]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
2, num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||
|
||||
# Get the metadata
|
||||
metadata: KVConnectorMetadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, SharedStorageConnectorMetadata)
|
||||
|
||||
if metadata is None:
|
||||
logger.warning(
|
||||
"In connector.start_load_kv, but the connector metadata is None"
|
||||
)
|
||||
return
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
logger.warning("In connector.start_load_kv, but the attn_metadata is None")
|
||||
return
|
||||
|
||||
# Load the KV for each request each layer
|
||||
for request in metadata.requests:
|
||||
if request.is_store:
|
||||
continue
|
||||
logger.info(
|
||||
"Inject KV cache of %d tokens to the paged memory",
|
||||
len(request.slot_mapping),
|
||||
)
|
||||
for layer_name in forward_context.no_compile_layers:
|
||||
layer = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
# Only process layers that have kv_cache
|
||||
# attribute (attention layers) Skip non-attention
|
||||
# layers like FusedMoE/MLP etc.
|
||||
kv_cache_attr = getattr(layer, "kv_cache", None)
|
||||
if kv_cache_attr is None:
|
||||
continue
|
||||
|
||||
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
|
||||
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
|
||||
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
return
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
|
||||
def extract_kv_from_layer(
|
||||
layer: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the KV cache from the layer.
|
||||
|
||||
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
||||
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
||||
"""
|
||||
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
|
||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
|
||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
|
||||
for request in connector_metadata.requests:
|
||||
if request.is_store:
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
|
||||
tensors = {"kv_cache": kv_cache.detach().cpu()}
|
||||
safetensors.torch.save_file(tensors, filename)
|
||||
|
||||
def wait_for_save(self):
|
||||
return
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
# NOTE: in this debug implementation, we assume that the prompt is
|
||||
# cached_prompt + newly_generated_single_token
|
||||
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
|
||||
|
||||
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
|
||||
# with the block granularity. And it expects the returned blocks and
|
||||
# num_computed_tokens to also be aligned with the block granularity.
|
||||
if not self._found_match_for_request(request):
|
||||
return 0, False
|
||||
|
||||
logger.info("External Cache Hit!")
|
||||
|
||||
# Now, first num_tokens_to_check tokens are hit, we need to prepare
|
||||
# the metadata for the worker connector to correctly load the KV
|
||||
token_ids = request.prompt_token_ids or []
|
||||
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
|
||||
|
||||
return num_tokens_to_check - num_computed_tokens, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If blocks were allocated, add to _requests_need_load,
|
||||
such that we load the KVs in the next forward pass.
|
||||
"""
|
||||
if num_external_tokens > 0:
|
||||
self._requests_need_load[request.request_id] = request
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = SharedStorageConnectorMetadata()
|
||||
|
||||
total_need_load = 0
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
token_ids = new_req.prompt_token_ids or []
|
||||
mm_hashes = [f.identifier for f in new_req.mm_features]
|
||||
if new_req.req_id in self._requests_need_load:
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
total_need_load += 1
|
||||
else:
|
||||
# NOTE: here, we set the store and load being exclusive,
|
||||
# but a single request can have both store and load.
|
||||
# NOTE(rob): for this debug implementation, we only cache
|
||||
# the original prompt tokens.
|
||||
if not self._found_match_for_prompt(token_ids, mm_hashes):
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=True,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
|
||||
if not resumed_from_preemption or req_id not in self._requests_need_load:
|
||||
continue
|
||||
|
||||
num_computed_tokens = cached_reqs.num_computed_tokens[i]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
|
||||
# NOTE(rob): cached_req_data does not have the full
|
||||
# list of token ids (only new tokens). So we look it
|
||||
# up in the actual request object.
|
||||
request = self._requests_need_load[req_id]
|
||||
total_tokens = num_computed_tokens + num_new_tokens
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=[f.identifier for f in request.mm_features],
|
||||
)
|
||||
total_need_load += 1
|
||||
|
||||
assert total_need_load == len(self._requests_need_load)
|
||||
self._requests_need_load.clear()
|
||||
return meta
|
||||
|
||||
# ==============================
|
||||
# Helper functions
|
||||
# ==============================
|
||||
|
||||
def _found_match_for_request(
|
||||
self,
|
||||
request: "Request",
|
||||
) -> bool:
|
||||
"""Check if the cache is hit for the request."""
|
||||
return self._found_match_for_prompt(
|
||||
list(request.prompt_token_ids or []),
|
||||
[f.identifier for f in request.mm_features],
|
||||
)
|
||||
|
||||
def _found_match_for_prompt(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
mm_hashes: list[str],
|
||||
) -> bool:
|
||||
num_tokens_to_check = align_to_block_size(
|
||||
len(prompt_token_ids) - 1, self._block_size
|
||||
)
|
||||
foldername = self._generate_foldername_debug(
|
||||
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
|
||||
mm_hashes,
|
||||
create_folder=False,
|
||||
)
|
||||
return os.path.exists(foldername)
|
||||
|
||||
def _generate_foldername_debug(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
mm_hashes: list[str],
|
||||
create_folder=False,
|
||||
) -> str:
|
||||
"""Generate a folder name based on the hash of the bytes of the input
|
||||
ids.
|
||||
"""
|
||||
token_bytes = token_ids.numpy().tobytes()
|
||||
# Add mm_hashes to the bytes being hashed to avoid path traversal and
|
||||
# to create a canonical key.
|
||||
if mm_hashes:
|
||||
mm_str = "-".join(mm_hashes)
|
||||
token_bytes += mm_str.encode("utf-8")
|
||||
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
|
||||
|
||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||
if create_folder:
|
||||
os.makedirs(foldername, exist_ok=True)
|
||||
return foldername
|
||||
|
||||
def _generate_filename_debug(
|
||||
self,
|
||||
layer_name: str,
|
||||
token_ids: torch.Tensor,
|
||||
mm_hashes: list[str],
|
||||
) -> str:
|
||||
"""Generate a file name based on the layer name and the hash
|
||||
of the bytes of the input ids.
|
||||
"""
|
||||
foldername = self._generate_foldername_debug(
|
||||
token_ids, mm_hashes=mm_hashes, create_folder=True
|
||||
)
|
||||
return os.path.join(foldername, f"{layer_name}.safetensors")
|
||||
|
||||
|
||||
def align_to_block_size(num_tokens: int, block_size) -> int:
|
||||
"""Align the number of tokens to the block size."""
|
||||
return (num_tokens - 1) // block_size * block_size
|
||||
Reference in New Issue
Block a user