61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import inspect
|
|
from collections.abc import Callable
|
|
from functools import wraps
|
|
|
|
from vllm.distributed.kv_transfer import (
|
|
get_kv_transfer_group,
|
|
has_kv_transfer_group,
|
|
is_v1_kv_transfer_group,
|
|
)
|
|
|
|
|
|
def maybe_transfer_kv_layer(func: Callable) -> Callable:
    """Decorator that handles KV layer transfer before and after execution
    of an attention layer, if enabled. Otherwise, the wrapper is a no-op.

    On entry: waits for the KV layer from the connector.
    On exit: saves the KV layer to the connector.

    The decorated function must declare a ``layer_name`` parameter. At call
    time it may be supplied positionally or as a keyword argument; if it is
    omitted, the parameter's declared default (if any) is used.

    Args:
        func: The attention function to wrap.

    Returns:
        A wrapper with the same call signature as ``func``.

    Raises:
        TypeError: At decoration time, if ``func`` has no ``layer_name``
            parameter.
    """
    # Import at runtime to avoid circular dependency
    from vllm.attention.layer import get_attention_context

    # Inspect the signature ONCE when the decorator is applied, so the
    # per-call path only performs cheap tuple/dict lookups.
    sig = inspect.signature(func)
    param_names = list(sig.parameters.keys())

    # Find the index of 'layer_name' parameter.
    try:
        layer_name_index = param_names.index("layer_name")
    except ValueError as e:
        raise TypeError(
            f"Function {func.__name__} must have a 'layer_name' parameter"
        ) from e

    layer_name_param = sig.parameters["layer_name"]
    # Only positional-capable parameters can be located by index in *args.
    # A keyword-only 'layer_name' (e.g. declared after *args) would otherwise
    # be read from an unrelated positional slot.
    positional_ok = layer_name_param.kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    # A positional-only parameter can never be bound from a keyword argument.
    keyword_ok = layer_name_param.kind is not inspect.Parameter.POSITIONAL_ONLY
    layer_name_default = layer_name_param.default

    def _resolve_layer_name(args: tuple, kwargs: dict) -> str | None:
        """Return the layer_name bound for this call, or None if absent."""
        if positional_ok and layer_name_index < len(args):
            return args[layer_name_index]
        if keyword_ok and "layer_name" in kwargs:
            return kwargs["layer_name"]
        if layer_name_default is not inspect.Parameter.empty:
            return layer_name_default
        return None

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
            return func(*args, **kwargs)

        layer_name = _resolve_layer_name(args, kwargs)
        if layer_name is None:
            # Nothing to transfer; let ``func`` itself report a missing
            # required argument with Python's usual TypeError.
            return func(*args, **kwargs)

        # Extract attention context (layer-specific metadata, layer, and
        # kv_cache); the layer object itself is not needed here.
        attn_metadata, _, kv_cache = get_attention_context(layer_name)
        connector = get_kv_transfer_group()
        if attn_metadata is None or not connector.has_connector_metadata():
            return func(*args, **kwargs)

        # Wait for KV layer on entry
        connector.wait_for_layer_load(layer_name)

        # Execute the function
        result = func(*args, **kwargs)

        # Save KV cache layer on exit
        connector.save_kv_layer(layer_name, kv_cache, attn_metadata)

        return result

    return wrapper
|