# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionLayer,
    AttentionType,
    is_quantized_kv_cache,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

logger = init_logger(__name__)

_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM)


class CPUAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16, torch.float32]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "CPU_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """CPU attention supports decoder, encoder, encoder-only and
        encoder-decoder attention."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )

    @staticmethod
    def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
        return CPUAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
        return CPUAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return 2, num_blocks, num_kv_heads, block_size, head_size

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

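# NOTE: the KV cache handed to this backend follows get_kv_cache_shape() above:
# (2, num_blocks, num_kv_heads, block_size, head_size). Index 0 of the leading
# dimension holds the key cache and index 1 the value cache; forward() below
# recovers them with kv_cache.unbind(0).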

@dataclass
class CPUAttentionMetadata:
    isa: str
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    scheduler_metadata: torch.Tensor | None
    causal: bool = True

    # Can be removed once the SDPA prefill fallback is deprecated.
    use_sdpa_prefill: bool = False
    num_decode_tokens: int = 0
    sdpa_attn_masks: list[torch.Tensor | None] | None = None
    sdpa_start_loc: torch.Tensor | None = None


class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.use_sdpa_prefill = False
        reorder_batch_threshold = None
        if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
            # In this case, decode seqs are reordered to the front of prefill
            # seqs to split decode and prefill. SDPA is then used for prefill
            # and cpu_attention_with_kv_cache for decode.
            reorder_batch_threshold = 1
            self.use_sdpa_prefill = True
        self._init_reorder_batch_threshold(reorder_batch_threshold, False)

        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config
        parallel_config = vllm_config.parallel_config
        self.num_kv_heads = vllm_config.model_config.get_num_kv_heads(parallel_config)
        self.num_heads = vllm_config.model_config.get_num_attention_heads(
            parallel_config
        )
        self.head_dim = kv_cache_spec.head_size
        self.dtype = vllm_config.model_config.dtype
        self.window_size = getattr(kv_cache_spec, "sliding_window", -1)
        if self.window_size is None:
            self.window_size = -1
        self.block_size = vllm_config.cache_config.block_size
        self.isa = _get_attn_isa(self.dtype, self.block_size)
        self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CPUAttentionMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = False if self.is_cross_attention else common_attn_metadata.causal

        sdpa_start_loc = query_start_loc
        num_decode_tokens = 0
        if self.use_sdpa_prefill and causal:
            # Decoder batch: requests were reordered, so truncate the metadata
            # to the decode portion; the prefill portion is handled by SDPA.
            assert self.reorder_batch_threshold
            (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
                split_decodes_and_prefills(
                    common_attn_metadata,
                    decode_threshold=self.reorder_batch_threshold,
                    require_uniform=True,
                )
            )
            num_reqs = num_decodes
            sdpa_start_loc = sdpa_start_loc[num_decodes:] - num_decode_tokens
            seq_lens = seq_lens[:num_decodes]
            query_start_loc = query_start_loc[: num_decodes + 1]
            block_table_tensor = block_table_tensor[:num_decodes]

        scheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
            num_reqs=num_reqs,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            seq_lens=seq_lens,
            dtype=self.dtype,
            query_start_loc=query_start_loc,
            causal=causal,
            sliding_window_size=self.window_size,
            isa=self.isa,
            enable_kv_split=True,
        )

        attn_metadata = CPUAttentionMetadata(
            isa=self.isa,
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            scheduler_metadata=scheduler_metadata,
            causal=causal,
            use_sdpa_prefill=self.use_sdpa_prefill,
            num_decode_tokens=num_decode_tokens,
            sdpa_start_loc=sdpa_start_loc,
        )
        return attn_metadata

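# Illustrative example of the SDPA-prefill split in build() above (the batch
# composition is made up): with reorder_batch_threshold == 1 and a reordered
# batch of two decode requests followed by one 5-token prefill, query_start_loc
# is [0, 1, 2, 7]. split_decodes_and_prefills() returns num_decodes=2 and
# num_decode_tokens=2, so the decode metadata keeps the first two requests,
# while sdpa_start_loc becomes [2, 7] - 2 = [0, 5] for the prefill tokens that
# fall back to SDPA.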

class CPUAttentionBackendImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        if logits_soft_cap is not None and attn_type in (
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
        ):
            logger.warning_once(
                "CPU_ATTN does not support logits softcap for"
                " ENCODER and ENCODER_ONLY, outputs may be slightly off"
            )
        if logits_soft_cap is None:
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")

        self.attn_type = attn_type
        self.sinks = sinks
        if self.sinks is not None:
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: CPUAttentionMetadata | None,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass for CPU attention backend.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [2, num_blocks, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for CPUAttentionBackendImpl"
            )

        # attn_metadata is None during warm-up runs; nothing to compute.
        if attn_metadata is None:
            return output

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache is needed.
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention, run SDPA directly over the inputs.
            return self._run_sdpa_forward(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                self.attn_type,
            )

        # For decoder and cross-attention, use the KV cache; key_cache and
        # value_cache each have shape
        # [num_blocks, num_kv_heads, block_size, head_size].
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in the KV cache.
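        # slot_mapping gives, for each new token, the flat slot to write in the
        # paged cache (conventionally block_id * block_size + in-block offset);
        # cpu_attn_reshape_and_cache performs that scatter below.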
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            ops.cpu_attn_reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                attn_metadata.isa,
            )

        if attn_metadata.use_sdpa_prefill:
            assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
            num_decode_tokens = attn_metadata.num_decode_tokens
            self._run_sdpa_forward(
                query[num_decode_tokens:num_actual_tokens],
                key[num_decode_tokens:num_actual_tokens],
                value[num_decode_tokens:num_actual_tokens],
                output[num_decode_tokens:num_actual_tokens],
                attn_metadata,
                self.attn_type,
            )
            num_actual_tokens = num_decode_tokens

        if num_actual_tokens > 0:
            ops.cpu_attention_with_kv_cache(
                query=query[:num_actual_tokens],
                key_cache=key_cache,
                value_cache=value_cache,
                output=output[:num_actual_tokens],  # type: ignore
                query_start_loc=attn_metadata.query_start_loc,
                seq_lens=attn_metadata.seq_lens,
                scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,  # type: ignore
                sliding_window=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=attn_metadata.scheduler_metadata,
                s_aux=self.sinks,
            )

        return output

    def _run_sdpa_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: CPUAttentionMetadata,
        attn_type: str,
    ) -> torch.Tensor:
        attn_masks = attn_metadata.sdpa_attn_masks
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes,
                    query.dtype,
                    attn_metadata.sdpa_start_loc,
                )
            elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.sdpa_start_loc,
                    self.sliding_window[0],
                    self.sliding_window[1],
                    query.dtype,
                )
            else:
                attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1)  # type: ignore
            attn_metadata.sdpa_attn_masks = attn_masks

        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        if self.num_kv_heads != self.num_heads:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)

        causal_attn = attn_type == AttentionType.DECODER
        sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy()  # type: ignore
        for i in range(len(attn_masks)):
            mask = attn_masks[i]
            start_q = sdpa_start_loc[i]
            end_q = sdpa_start_loc[i + 1]
            sub_out = (
                torch.nn.functional.scaled_dot_product_attention(
                    query[None, :, start_q:end_q, :],
                    key[None, :, start_q:end_q, :],
                    value[None, :, start_q:end_q, :],
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=causal_attn and mask is None,
                    scale=self.scale,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = sub_out
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    sdpa_start_loc: torch.Tensor,
) -> list[torch.Tensor]:
    attn_biases: list[torch.Tensor] = []
    seq_num = sdpa_start_loc.size(0) - 1
    sdpa_start_loc = sdpa_start_loc.numpy()  # type: ignore
    for i in range(seq_num):
        seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
        bias = torch.arange(seq_len, dtype=dtype)  # type: ignore
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
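        # For example, with seq_len = 3 the subtraction below yields the
        # relative-position matrix
        #     [[ 0,  1,  2],
        #      [-1,  0,  1],
        #      [-2, -1,  0]]
        # which is then scaled per head by alibi_slopes and masked causally
        # with inf_mask.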
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
        inf_mask = (
            torch.empty((1, seq_len, seq_len), dtype=bias.dtype)  # type: ignore
            .fill_(-torch.inf)
            .triu_(diagonal=1)
        )
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
    sdpa_start_loc: torch.Tensor,
    left_window_size: int,
    right_window_size: int,
    dtype: torch.dtype,
) -> list[torch.Tensor]:
    attn_biases: list[torch.Tensor] = []
    seq_num = sdpa_start_loc.size(0) - 1
    sdpa_start_loc = sdpa_start_loc.numpy()  # type: ignore
    for i in range(seq_num):
        seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
        mask = torch.full(  # type: ignore
            (1, seq_len, seq_len),  # type: ignore
            fill_value=1,
            dtype=dtype,
        )
        if right_window_size != -1:
            mask = torch.tril(mask, diagonal=right_window_size)
        if left_window_size != -1:
            mask = torch.triu(mask, diagonal=-left_window_size)
        mask = torch.log(mask)
        attn_biases.append(mask)
    return attn_biases


def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
    supports_amx = torch._C._cpu._is_amx_tile_supported()
    if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
        return "amx"
    elif block_size % 32 == 0:
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            return "neon"
        else:
            return "vec"
    else:
        return "vec16"
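

# Illustrative usage of the helpers above (reference only; the values are made up):
#   - _get_attn_isa(torch.bfloat16, 64) returns "amx" when AMX tiles are
#     available, otherwise "neon" on ARM or "vec" elsewhere; block sizes not
#     divisible by 32 fall back to "vec16".
#   - _make_sliding_window_bias(torch.tensor([0, 4]), 1, 0, torch.float32)
#     returns one 1x4x4 additive mask that is 0 where key position j lies in
#     [i - 1, i] and -inf elsewhere.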