# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from dataclasses import dataclass, replace from enum import Enum from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args import numpy as np import torch from typing_extensions import deprecated if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.config.cache import CacheDType from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.kv_cache_interface import AttentionSpec class AttentionType(str, Enum): """ Attention type. Use string to be compatible with `torch.compile`. """ DECODER = "decoder" """Decoder attention between previous layer Q/K/V.""" ENCODER = "encoder" """Encoder attention between previous layer Q/K/V for encoder-decoder.""" ENCODER_ONLY = "encoder_only" """Encoder attention between previous layer Q/K/V.""" ENCODER_DECODER = "encoder_decoder" """Attention between dec. Q and enc. K/V for encoder-decoder.""" class MultipleOf: base: int def __init__(self, base: int): self.base = base class AttentionBackend(ABC): """Abstract class for attention backends.""" # For some attention backends, we allocate an output tensor before # calling the custom op. When piecewise cudagraph is enabled, this # makes sure the output tensor is allocated inside the cudagraph. accept_output_buffer: bool = False supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"] # Does attention's forward() include kv cache update? forward_includes_kv_cache_update: bool = True @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: return [MultipleOf(1)] @staticmethod @abstractmethod def get_name() -> str: raise NotImplementedError @staticmethod @abstractmethod def get_impl_cls() -> type["AttentionImplBase"]: raise NotImplementedError @staticmethod @abstractmethod def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: raise NotImplementedError @staticmethod @abstractmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", ) -> tuple[int, ...]: raise NotImplementedError @staticmethod def get_kv_cache_stride_order( include_num_layers_dimension: bool = False, ) -> tuple[int, ...]: """ Get the physical (memory layout) ordering of the kv cache dimensions. e.g. if the KV cache shape is [2, num_blocks, block_size, num_heads, head_size], and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical ordering of dimensions is [num_blocks, num_heads, 2, block_size, head_size]. If this function is unimplemented / raises NotImplementedError, the physical layout of the KV cache will match the logical shape. Args: include_num_layers_dimension: if True, includes an additional num_layers dimension, which is assumed to be prepended to the logical KV cache shape. With the above example, a return value (2, 4, 0, 1, 3, 5) corresponds to [num_blocks, num_heads, num_layers, 2, block_size, head_size]. If an additional dimension is NOT included in the returned tuple, the physical layout will not include a layers dimension. Returns: A tuple of ints which is a permutation of range(len(shape)). """ raise NotImplementedError @classmethod def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) @classmethod def get_supported_head_sizes(cls) -> list[int]: return [] @classmethod def supports_head_size(cls, head_size: int) -> bool: supported_head_sizes = cls.get_supported_head_sizes() return (not supported_head_sizes) or head_size in supported_head_sizes @classmethod def supports_dtype(cls, dtype: torch.dtype) -> bool: return dtype in cls.supported_dtypes @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: if kv_cache_dtype is None: return True return (not cls.supported_kv_cache_dtypes) or ( kv_cache_dtype in cls.supported_kv_cache_dtypes ) @classmethod def supports_block_size(cls, block_size: int | None) -> bool: from vllm.config.cache import BlockSize if block_size is None: return True valid_sizes = get_args(BlockSize) if block_size not in valid_sizes: return False supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes() if not supported_kernel_block_sizes: return True for supported_size in supported_kernel_block_sizes: if isinstance(supported_size, MultipleOf): supported_size = supported_size.base # With hybrid_blocks feature, the framework-level block size # only needs to be a multiple of the kernel's requirement, # even if the kernel requires a fixed block_size. if block_size % supported_size == 0: return True return False @classmethod def is_mla(cls) -> bool: return False @classmethod def supports_sink(cls) -> bool: return False @classmethod def supports_alibi_sqrt(cls) -> bool: return False @classmethod def supports_mm_prefix(cls) -> bool: return False @classmethod def is_sparse(cls) -> bool: return False @classmethod def supports_per_head_quant_scales(cls) -> bool: return False @classmethod def supports_attn_type(cls, attn_type: str) -> bool: """Check if backend supports a given attention type. By default, only supports decoder attention. Backends should override this to support other attention types. """ return attn_type == AttentionType.DECODER @classmethod def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: return True @classmethod def supports_combination( cls, head_size: int, dtype: torch.dtype, kv_cache_dtype: "CacheDType | None", block_size: int, use_mla: bool, has_sink: bool, use_sparse: bool, device_capability: "DeviceCapability", ) -> str | None: return None @classmethod def validate_configuration( cls, head_size: int, dtype: torch.dtype, kv_cache_dtype: "CacheDType | None", block_size: int, use_mla: bool, has_sink: bool, use_sparse: bool, use_mm_prefix: bool, use_per_head_quant_scales: bool, device_capability: "DeviceCapability", attn_type: str, ) -> list[str]: invalid_reasons = [] if not cls.supports_head_size(head_size): invalid_reasons.append("head_size not supported") if not cls.supports_dtype(dtype): invalid_reasons.append("dtype not supported") if not cls.supports_kv_cache_dtype(kv_cache_dtype): invalid_reasons.append("kv_cache_dtype not supported") if not cls.supports_block_size(block_size): invalid_reasons.append("block_size not supported") if use_mm_prefix and not cls.supports_mm_prefix(): invalid_reasons.append( "partial multimodal token full attention not supported" ) if use_mla != cls.is_mla(): if use_mla: invalid_reasons.append("MLA not supported") else: invalid_reasons.append("non-MLA not supported") if has_sink and not cls.supports_sink(): invalid_reasons.append("sink setting not supported") if use_sparse != cls.is_sparse(): if use_sparse: invalid_reasons.append("sparse not supported") else: invalid_reasons.append("non-sparse not supported") if use_per_head_quant_scales and not cls.supports_per_head_quant_scales(): invalid_reasons.append("per-head quant scales not supported") if not cls.supports_compute_capability(device_capability): invalid_reasons.append("compute capability not supported") if not cls.supports_attn_type(attn_type): invalid_reasons.append(f"attention type {attn_type} not supported") combination_reason = cls.supports_combination( head_size, dtype, kv_cache_dtype, block_size, use_mla, has_sink, use_sparse, device_capability, ) if combination_reason is not None: invalid_reasons.append(combination_reason) return invalid_reasons @classmethod def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": return None class AttentionMetadata: pass T = TypeVar("T", bound=AttentionMetadata) @dataclass class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. For many of the tensors we keep both GPU and CPU versions. """ query_start_loc: torch.Tensor query_start_loc_cpu: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" seq_lens: torch.Tensor """(batch_size,), the number of computed tokens for each request""" num_reqs: int """Number of requests""" # TODO(lucas): rename to num_tokens since it may be padded and this is misleading num_actual_tokens: int """Total number of tokens in batch""" max_query_len: int """Longest query in batch""" max_seq_len: int """Longest context length (may be an upper bound)""" block_table_tensor: torch.Tensor slot_mapping: torch.Tensor causal: bool = True # Needed by FastPrefillAttentionBuilder logits_indices_padded: torch.Tensor | None = None num_logits_indices: int | None = None # Needed by CrossAttentionBuilder encoder_seq_lens: torch.Tensor | None = None encoder_seq_lens_cpu: np.ndarray | None = None dcp_local_seq_lens: torch.Tensor | None = None dcp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0) _seq_lens_cpu: torch.Tensor | None = None _num_computed_tokens_cpu: torch.Tensor | None = None _num_computed_tokens_cache: torch.Tensor | None = None def batch_size(self) -> int: return self.seq_lens.shape[0] def naive_query_lens(self) -> torch.Tensor: """Naive because it assumes that query ends where the next query starts.""" return self.query_start_loc[1:] - self.query_start_loc[:-1] def replace(self, **kwargs) -> "CommonAttentionMetadata": return replace(self, **kwargs) @property @deprecated( """ Prefer using device seq_lens directly to avoid implicit H<>D sync. If a CPU copy is needed, use `seq_lens.cpu()` instead. Will be removed in a future release, please migrate as soon as possible. """ ) def seq_lens_cpu(self) -> torch.Tensor: if self._seq_lens_cpu is None: self._seq_lens_cpu = self.seq_lens.to("cpu") return self._seq_lens_cpu @property @deprecated( """ Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full async scheduling. If a CPU copy is needed, it can be derived from query_start_loc_cpu and seq_lens. Will be removed in a future release, please migrate as soon as possible. """ ) def num_computed_tokens_cpu(self) -> torch.Tensor: if self._num_computed_tokens_cpu is None: query_seq_lens = ( self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] ) self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens return self._num_computed_tokens_cpu def compute_num_computed_tokens(self) -> torch.Tensor: """Compute num_computed_tokens on device (seq_lens - query_lens).""" if self._num_computed_tokens_cache is None: query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] self._num_computed_tokens_cache = self.seq_lens - query_lens return self._num_computed_tokens_cache # TODO(lucas): remove once we have FULL-CG spec-decode support def unpadded( self, num_actual_tokens: int, num_actual_reqs: int ) -> "CommonAttentionMetadata": maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None return CommonAttentionMetadata( query_start_loc=self.query_start_loc[: num_actual_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], seq_lens=self.seq_lens[:num_actual_reqs], _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs] if self._seq_lens_cpu is not None else None, _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] if self._num_computed_tokens_cpu is not None else None, num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, max_seq_len=self.max_seq_len, block_table_tensor=self.block_table_tensor[:num_actual_reqs], slot_mapping=self.slot_mapping[:num_actual_tokens], causal=self.causal, logits_indices_padded=self.logits_indices_padded, num_logits_indices=self.num_logits_indices, encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens), encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu), dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens), dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu), ) M = TypeVar("M") class AttentionCGSupport(Enum): """Constants for the cudagraph support of the attention backend Here we do not consider the cascade attention, as currently it is never cudagraph supported.""" ALWAYS = 3 """Cudagraph always supported; supports mixed-prefill-decode""" UNIFORM_BATCH = 2 """Cudagraph supported for batches the only contain query lengths that are the same, this can be used for spec-decode i.e. "decodes" are 1 + num_speculative_tokens""" UNIFORM_SINGLE_TOKEN_DECODE = 1 """Cudagraph supported for batches the only contain query_len==1 decodes""" NEVER = 0 """NO cudagraph support""" class AttentionMetadataBuilder(ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention (default: no). # Do not access directly. Call get_cudagraph_support() instead. _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. reorder_batch_threshold: int | None = None # Does this backend/builder support updating the block table in existing # metadata supports_update_block_table: bool = False @abstractmethod def __init__( self, kv_cache_spec: "AttentionSpec", layer_names: list[str], vllm_config: "VllmConfig", device: torch.device, ): self.kv_cache_spec = kv_cache_spec self.layer_names = layer_names self.vllm_config = vllm_config self.device = device @classmethod def get_cudagraph_support( cls: type["AttentionMetadataBuilder"], vllm_config: "VllmConfig", kv_cache_spec: "AttentionSpec", ) -> AttentionCGSupport: """Get the cudagraph support level of this builder class.""" return cls._cudagraph_support def _init_reorder_batch_threshold( self, reorder_batch_threshold: int | None = 1, supports_spec_as_decode: bool = False, supports_dcp_with_varlen: bool = False, ) -> None: self.reorder_batch_threshold = reorder_batch_threshold if self.reorder_batch_threshold is not None and supports_spec_as_decode: # If the backend supports spec-as-decode kernels, then we can set # the reorder_batch_threshold based on the number of speculative # tokens from the config. speculative_config = self.vllm_config.speculative_config if ( speculative_config is not None and speculative_config.num_speculative_tokens is not None ): max_num_queries_for_spec = ( 1 + (2 if speculative_config.parallel_drafting else 1) * speculative_config.num_speculative_tokens ) self.reorder_batch_threshold = max( self.reorder_batch_threshold, max_num_queries_for_spec, ) if ( self.vllm_config.parallel_config.decode_context_parallel_size > 1 and not supports_dcp_with_varlen ): self.reorder_batch_threshold = 1 @abstractmethod def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. Args: common_prefix_len: The length of the common prefix of the batch. common_attn_metadata: The common attention metadata. fast_build: The meta-data will prioritize speed of building over then speed at execution. Can be used for spec-decode where the result of a build call may only be used for few layers/iters. """ raise NotImplementedError def update_block_table( self, metadata: M, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: """ Update the block table for the attention metadata. Faster when theres multiple kv-cache groups that create virtually the same metadata but just with different block tables. Only needs to be implemented if supports_update_block_table is True. """ raise NotImplementedError def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> M: """ Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture. """ return self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata ) def build_for_drafting( self, common_attn_metadata: CommonAttentionMetadata, draft_index: int, ) -> M: """ Build attention metadata for draft model. Uses build by default. Args: common_attn_metadata: The common attention metadata. draft_index: The index of the current draft operation. When speculating a chain of tokens, this index refers to the draft attempt for the i-th token. For tree-based attention, this index instead refers to the draft attempt for the i-th level in the tree of tokens. """ return self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, fast_build=True, ) def use_cascade_attention( self, common_prefix_len: int, query_lens: np.ndarray, num_query_heads: int, num_kv_heads: int, use_alibi: bool, use_sliding_window: bool, use_local_attention: bool, num_sms: int, dcp_world_size: int, ) -> bool: return False class AttentionLayer(Protocol): _q_scale: torch.Tensor _k_scale: torch.Tensor _v_scale: torch.Tensor _q_scale_float: float _k_scale_float: float _v_scale_float: float _prob_scale: torch.Tensor def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: ... class AttentionImplBase(ABC, Generic[T]): """Base class for attention implementations. Contains common attributes and initialization logic shared by both standard AttentionImpl and MLAAttentionImpl. Does not define a forward method - subclasses define their own forward interfaces. """ # Required attributes that all impls should have num_heads: int head_size: int scale: float # Whether the attention impl can return the softmax lse for decode. # Some features like decode context parallelism require the softmax lse. can_return_lse_for_decode: bool = False # Whether the attention impl supports Prefill Context Parallelism. supports_pcp: bool = False # Whether the attention impl(or ops) supports MTP # when cp_kv_cache_interleave_size > 1 supports_mtp_with_cp_non_trivial_interleave_size: bool = False # some attention backends might not always want to return lse # even if they can return lse (for efficiency reasons) need_to_return_lse_for_decode: bool = False # Whether this attention implementation supports pre-quantized query input. # When True, the attention layer will quantize queries before passing them # to this backend, allowing torch.compile to fuse the quantization with # previous operations. This is typically supported when using FP8 KV cache # with compatible attention kernels (e.g., TRT-LLM). # Subclasses should set this in __init__. # TODO add support to more backends: # https://github.com/vllm-project/vllm/issues/25584 supports_quant_query_input: bool = False dcp_world_size: int dcp_rank: int pcp_world_size: int pcp_rank: int total_cp_world_size: int total_cp_rank: int def __new__(cls, *args, **kwargs): # use __new__ so that all subclasses will call this self = super().__new__(cls) try: from vllm.distributed.parallel_state import get_dcp_group self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 try: from vllm.distributed.parallel_state import get_pcp_group self.pcp_world_size = get_pcp_group().world_size self.pcp_rank = get_pcp_group().rank_in_group except AssertionError: self.pcp_world_size = 1 self.pcp_rank = 0 self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank self.need_to_return_lse_for_decode = ( self.dcp_world_size > 1 and self.can_return_lse_for_decode ) return self def process_weights_after_loading(self, act_dtype: torch.dtype): pass class AttentionImpl(AttentionImplBase[T], Generic[T]): """Standard attention implementation with forward method.""" @abstractmethod def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int | None = None, alibi_slopes: list[float] | None = None, sliding_window: int | None = None, kv_cache_dtype: str = "auto", logits_soft_cap: float | None = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, ) -> None: raise NotImplementedError @abstractmethod def forward( self, layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError def fused_output_quant_supported(self, quant_key: "QuantKey"): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it. :param quant_key: QuantKey object that describes the quantization op :return: is fusion supported for this type of quantization """ return False def fused_rope_kvcache_supported(self): """ Does this attention implementation support RoPE+KVCache fusion. This is used by the RopeKVCacheFusionPass to only fuse the RoPE ops with the KV cache update for implementations that support it. """ return False def do_rope_and_kv_cache_update( self, layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, positions: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, ): """ If `fused_rope_kvcache_supported` returns True, this method will be called by torch.ops.vllm.fused_rope_and_unified_kv_cache_update to perform the inplace RoPE and KV cache update. """ raise NotImplementedError class MLAAttentionImpl(AttentionImplBase[T], Generic[T]): """MLA attention implementation with forward_mqa and forward_mha methods.""" @abstractmethod def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: list[float] | None, sliding_window: int | None, kv_cache_dtype: str, logits_soft_cap: float | None, attn_type: str, kv_sharing_target_layer_name: str | None, # MLA Specific Arguments q_lora_rank: int | None, kv_lora_rank: int, qk_nope_head_dim: int, qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, kv_b_proj: "ColumnParallelLinear", indexer: object | None = None, q_pad_num_heads: int | None = None, ) -> None: raise NotImplementedError @abstractmethod def forward_mha( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: T, k_scale: torch.Tensor, output: torch.Tensor, ) -> None: """MHA-style prefill forward pass.""" raise NotImplementedError @abstractmethod def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: T, layer: AttentionLayer, ) -> tuple[torch.Tensor, torch.Tensor | None]: """MQA-style decode forward pass.""" raise NotImplementedError class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]): """Sparse MLA attention implementation with only forward_mqa method. Sparse MLA implementations only support decode (MQA-style) attention. They do not support prefill (MHA-style) attention. """ @abstractmethod def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: list[float] | None, sliding_window: int | None, kv_cache_dtype: str, logits_soft_cap: float | None, attn_type: str, kv_sharing_target_layer_name: str | None, # MLA Specific Arguments q_lora_rank: int | None, kv_lora_rank: int, qk_nope_head_dim: int, qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, kv_b_proj: "ColumnParallelLinear", indexer: object | None = None, q_pad_num_heads: int | None = None, ) -> None: raise NotImplementedError @abstractmethod def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: T, layer: AttentionLayer, ) -> tuple[torch.Tensor, torch.Tensor | None]: """MQA-style decode forward pass.""" raise NotImplementedError def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: return kv_cache_dtype.startswith("fp8") def subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], builder_cls: type[AttentionMetadataBuilder[M]], ) -> type[AttentionBackend]: """ Return a new subclass where `get_builder_cls` returns `builder_cls`. """ name: str = name_prefix + attention_backend_cls.__name__ # type: ignore return type( name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls} ) def subclass_attention_backend_with_overrides( name_prefix: str, attention_backend_cls: type[AttentionBackend], overrides: dict[str, Any], ) -> type[AttentionBackend]: name: str = name_prefix + attention_backend_cls.__name__ # type: ignore return type(name, (attention_backend_cls,), overrides)