Sync from v0.13
This commit is contained in:
0
vllm/attention/utils/__init__.py
Normal file
0
vllm/attention/utils/__init__.py
Normal file
118
vllm/attention/utils/fa_utils.py
Normal file
118
vllm/attention/utils/fa_utils.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func
|
||||
get_scheduler_metadata = ops.get_scheduler_metadata
|
||||
elif current_platform.is_rocm():
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Rocm platform requires upstream flash-attn "
|
||||
"to be installed. Please install flash-attn first."
|
||||
) from e
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
# import here to avoid circular dependencies
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_xpu():
|
||||
return 2
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
fa_version_unsupported_reason,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
device_capability = current_platform.get_device_capability()
|
||||
|
||||
assert device_capability is not None
|
||||
|
||||
# 1. default version depending on platform
|
||||
fa_version = (
|
||||
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
|
||||
)
|
||||
|
||||
# 2. override if passed by environment or config
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.attention_config.flash_attn_version is not None:
|
||||
fa_version = vllm_config.attention_config.flash_attn_version
|
||||
|
||||
# 3. fallback for unsupported combinations
|
||||
if device_capability.major == 10 and fa_version == 3:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 3 on Blackwell platform "
|
||||
"defaulting to FA version 2."
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if requires_alibi and fa_version == 3:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if not is_fa_version_supported(fa_version):
|
||||
logger.error(
|
||||
"Cannot use FA version %d is not supported due to %s",
|
||||
fa_version,
|
||||
fa_version_unsupported_reason(fa_version),
|
||||
)
|
||||
|
||||
assert is_fa_version_supported(fa_version)
|
||||
return fa_version
|
||||
except (ImportError, AssertionError):
|
||||
return None
|
||||
|
||||
|
||||
def flash_attn_supports_fp8() -> bool:
|
||||
return (
|
||||
get_flash_attn_version() == 3
|
||||
and current_platform.get_device_capability().major == 9
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_supports_sinks() -> bool:
|
||||
if current_platform.is_xpu():
|
||||
return True
|
||||
else:
|
||||
return get_flash_attn_version() == 3
|
||||
|
||||
|
||||
def flash_attn_supports_mla():
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
return (
|
||||
is_fa_version_supported(3)
|
||||
and current_platform.get_device_capability()[0] == 9
|
||||
)
|
||||
except (ImportError, AssertionError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def is_flash_attn_varlen_func_available() -> bool:
|
||||
return current_platform.is_cuda() or current_platform.is_xpu()
|
||||
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def validate_kv_sharing_target(
|
||||
current_layer_name, target_layer_name, static_forward_context
|
||||
):
|
||||
error_msg = (
|
||||
f"Specified KV sharing target layer for {current_layer_name} "
|
||||
f"is not valid: target layer {target_layer_name} "
|
||||
)
|
||||
|
||||
if current_layer_name == target_layer_name:
|
||||
raise ValueError(error_msg + "cannot be the same as the current layer.")
|
||||
|
||||
if target_layer_name not in static_forward_context:
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
# If target layer name is not in the static fwd context, it means either
|
||||
# a) the target layer does not come BEFORE the current layer, or
|
||||
# b) the target layer is not an Attention layer that exists in the model
|
||||
current_layer_idx = extract_layer_index(current_layer_name)
|
||||
target_layer_idx = extract_layer_index(target_layer_name)
|
||||
if current_layer_idx <= target_layer_idx:
|
||||
raise ValueError(error_msg + "must come before the current layer.")
|
||||
else:
|
||||
raise ValueError(error_msg + "is not a valid Attention layer in the model.")
|
||||
|
||||
# Currently KV sharing is only supported between layers of the same type
|
||||
target_layer_attn_type = static_forward_context[target_layer_name].attn_type
|
||||
expected = static_forward_context[current_layer_name].attn_type
|
||||
if target_layer_attn_type != expected:
|
||||
raise ValueError(
|
||||
error_msg + f"must be the same type as the current layer ({expected})."
|
||||
)
|
||||
60
vllm/attention/utils/kv_transfer_utils.py
Normal file
60
vllm/attention/utils/kv_transfer_utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
|
||||
|
||||
def maybe_transfer_kv_layer(func: Callable) -> Callable:
|
||||
"""Decorator that handles KV layer transfer prior and after execution of
|
||||
an attention layer, if enabled. Otherwise, the wrapper is a no-op.
|
||||
|
||||
On entry: waits for the KV layer from the connector.
|
||||
On exit: saves the KV layer to the connector.
|
||||
"""
|
||||
# Import at runtime to avoid circular dependency
|
||||
from vllm.attention.layer import get_attention_context
|
||||
|
||||
# Inspect the signature ONCE when the decorator is applied.
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
# Find the index of 'layer_name' parameter.
|
||||
try:
|
||||
layer_name_index = param_names.index("layer_name")
|
||||
except ValueError as e:
|
||||
raise TypeError(
|
||||
f"Function {func.__name__} must have a 'layer_name' parameter"
|
||||
) from e
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
layer_name: str = args[layer_name_index]
|
||||
|
||||
# Extract attention context (layer-specific metadata, layer, and kv_cache)
|
||||
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
|
||||
connector = get_kv_transfer_group()
|
||||
if attn_metadata is None or not connector.has_connector_metadata():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Wait for KV layer on entry
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
# Execute the function
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Save KV cache layer on exit
|
||||
connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user