update
This commit is contained in:
195
vllm_old/attention/backends/registry.py
Normal file
195
vllm_old/attention/backends/registry.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention backend registry"""
|
||||
|
||||
import enum
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _AttentionBackendEnumMeta(enum.EnumMeta):
|
||||
"""Metaclass for AttentionBackendEnum to provide better error messages."""
|
||||
|
||||
def __getitem__(cls, name: str):
|
||||
"""Get backend by name with helpful error messages."""
|
||||
try:
|
||||
return super().__getitem__(name)
|
||||
except KeyError:
|
||||
members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
|
||||
valid_backends = ", ".join(m.name for m in members)
|
||||
raise ValueError(
|
||||
f"Unknown attention backend: '{name}'. "
|
||||
f"Valid options are: {valid_backends}"
|
||||
) from None
|
||||
|
||||
|
||||
class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""Enumeration of all supported attention backends.
|
||||
|
||||
The enum value is the default class path, but this can be overridden
|
||||
at runtime using register_backend().
|
||||
|
||||
To get the actual backend class (respecting overrides), use:
|
||||
backend.get_class()
|
||||
"""
|
||||
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
|
||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
||||
ROCM_AITER_FA = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
)
|
||||
TORCH_SDPA = "" # this tag is only used for ViT
|
||||
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
|
||||
FLASHINFER_MLA = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
FLASHMLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
|
||||
)
|
||||
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
|
||||
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
|
||||
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
|
||||
ROCM_AITER_UNIFIED_ATTN = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
|
||||
"RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
|
||||
# Placeholder for third-party/custom backends - must be registered before use
|
||||
CUSTOM = ""
|
||||
|
||||
def get_path(self, include_classname: bool = True) -> str:
|
||||
"""Get the class path for this backend (respects overrides).
|
||||
|
||||
Returns:
|
||||
The fully qualified class path string
|
||||
|
||||
Raises:
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
path = _OVERRIDES.get(self, self.value)
|
||||
if not path:
|
||||
raise ValueError(
|
||||
f"Backend {self.name} must be registered before use. "
|
||||
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
|
||||
)
|
||||
if not include_classname:
|
||||
path = path.rsplit(".", 1)[0]
|
||||
return path
|
||||
|
||||
def get_class(self) -> "type[AttentionBackend]":
|
||||
"""Get the backend class (respects overrides).
|
||||
|
||||
Returns:
|
||||
The backend class
|
||||
|
||||
Raises:
|
||||
ImportError: If the backend class cannot be imported
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
return resolve_obj_by_qualname(self.get_path())
|
||||
|
||||
def is_overridden(self) -> bool:
|
||||
"""Check if this backend has been overridden.
|
||||
|
||||
Returns:
|
||||
True if the backend has a registered override
|
||||
"""
|
||||
return self in _OVERRIDES
|
||||
|
||||
def clear_override(self) -> None:
|
||||
"""Clear any override for this backend, reverting to the default."""
|
||||
_OVERRIDES.pop(self, None)
|
||||
|
||||
|
||||
_OVERRIDES: dict[AttentionBackendEnum, str] = {}
|
||||
|
||||
|
||||
def register_backend(
|
||||
backend: AttentionBackendEnum, class_path: str | None = None
|
||||
) -> Callable[[type], type]:
|
||||
"""Register or override a backend implementation.
|
||||
|
||||
Args:
|
||||
backend: The AttentionBackendEnum member to register
|
||||
class_path: Optional class path. If not provided and used as
|
||||
decorator, will be auto-generated from the class.
|
||||
|
||||
Returns:
|
||||
Decorator function if class_path is None, otherwise a no-op
|
||||
|
||||
Examples:
|
||||
# Override an existing backend
|
||||
@register_backend(AttentionBackendEnum.FLASH_ATTN)
|
||||
class MyCustomFlashAttn:
|
||||
...
|
||||
|
||||
# Register a custom third-party backend
|
||||
@register_backend(AttentionBackendEnum.CUSTOM)
|
||||
class MyCustomBackend:
|
||||
...
|
||||
|
||||
# Direct registration
|
||||
register_backend(
|
||||
AttentionBackendEnum.CUSTOM,
|
||||
"my.module.MyCustomBackend"
|
||||
)
|
||||
"""
|
||||
|
||||
def decorator(cls: type) -> type:
|
||||
_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
|
||||
return cls
|
||||
|
||||
if class_path is not None:
|
||||
_OVERRIDES[backend] = class_path
|
||||
return lambda x: x
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Backwards compatibility alias for plugins
|
||||
class _BackendMeta(type):
|
||||
"""Metaclass to provide deprecation warnings when accessing _Backend."""
|
||||
|
||||
def __getattribute__(cls, name: str):
|
||||
if name not in ("__class__", "__mro__", "__name__"):
|
||||
logger.warning(
|
||||
"_Backend has been renamed to AttentionBackendEnum. "
|
||||
"Please update your code to use AttentionBackendEnum instead. "
|
||||
"_Backend will be removed in a future release."
|
||||
)
|
||||
return getattr(AttentionBackendEnum, name)
|
||||
|
||||
def __getitem__(cls, name: str):
|
||||
logger.warning(
|
||||
"_Backend has been renamed to AttentionBackendEnum. "
|
||||
"Please update your code to use AttentionBackendEnum instead. "
|
||||
"_Backend will be removed in a future release."
|
||||
)
|
||||
return AttentionBackendEnum[name]
|
||||
|
||||
|
||||
class _Backend(metaclass=_BackendMeta):
|
||||
"""Deprecated: Use AttentionBackendEnum instead.
|
||||
|
||||
This class is provided for backwards compatibility with plugins
|
||||
and will be removed in a future release.
|
||||
"""
|
||||
|
||||
pass
|
||||
Reference in New Issue
Block a user