# SPDX-License-Identifier: Apache-2.0

import os
from contextlib import contextmanager
from functools import cache
from typing import Generator, Optional, Type

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

logger = init_logger(__name__)


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Convert a string backend name to a _Backend enum value.

    Returns:

    * _Backend: enum value if backend_name is a valid in-tree type
    * None: otherwise it's an invalid in-tree type or an out-of-tree
      platform is loaded.
    """
    assert backend_name is not None
    return (_Backend[backend_name]
            if backend_name in _Backend.__members__ else None)


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Get the backend override specified by the vLLM attention
    backend environment variable, if one is specified.

    Returns:

    * _Backend enum value if an override is specified
    * None otherwise
    '''
    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
    return (None
            if backend_name is None else backend_name_to_enum(backend_name))


# Global state allows a particular choice of backend to be forced,
# overriding the logic which auto-selects a backend based on system &
# workload configuration (default behavior if this variable is None).
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Get the currently-forced choice of attention backend,
    or None if auto-selection is currently enabled.
    '''
    return forced_attn_backend
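

# Illustrative usage of the forcing API above -- a sketch kept in comments so
# that importing this module stays side-effect free. `_Backend.FLASH_ATTN` is
# an assumed example; any valid `_Backend` member for your platform works:
#
#     from vllm.platforms import _Backend
#
#     global_force_attn_backend(_Backend.FLASH_ATTN)  # force a backend
#     assert get_global_forced_attn_backend() is _Backend.FLASH_ATTN
#     global_force_attn_backend(None)  # revert to automatic selection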


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_mla: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""
    # Accessing envs.* behind the @cache decorator can cause the wrong value
    # to be returned from the cache if the value changes between calls.
    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly
    # to the cached private function.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        is_attention_free=is_attention_free,
        is_blocksparse=is_blocksparse,
        use_v1=envs.VLLM_USE_V1,
        use_mla=use_mla,
    )


@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_v1: bool = False,
    use_mla: bool = False,
) -> Type[AttentionBackend]:
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from vllm.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION backend.
    if is_attention_free:
        from vllm.attention.backends.placeholder_attn import (
            PlaceholderAttentionBackend)
        return PlaceholderAttentionBackend

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    selected_backend = None
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)

    # Get the device-specific attention backend class.
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
        use_mla)
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")
    return resolve_obj_by_qualname(attention_cls)


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    '''

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)
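

# Illustrative end-to-end sketch, kept in comments so that importing this
# module has no side effects. The enum member and argument values are
# assumptions chosen for the example, not defaults defined in this file:
#
#     import torch
#     from vllm.platforms import _Backend
#
#     with global_force_attn_backend_context_manager(_Backend.XFORMERS):
#         # Resolves via the global override, which takes precedence over
#         # the VLLM_ATTENTION_BACKEND environment variable and
#         # auto-selection.
#         backend_cls = get_attn_backend(
#             head_size=128,
#             dtype=torch.float16,
#             kv_cache_dtype=None,
#             block_size=16,
#             is_attention_free=False,
#         )
#     # On exit, the prior override (or auto-selection) is restored.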