# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Optional import os import numpy as np import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) # from vllm_vacc.vllm.attention.backends.vacc_attn import (VACCAttentionBackendImpl, # VACCAttentionMetadata) # from vllm_vacc.vllm.attention.backends.vacc_attn import VACCAttentionBackendImpl # from vllm.attention.backends.utils import CommonAttentionState from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.cpu_model_runner import CPUModelRunner from vllm.v1.worker.gpu_input_batch import InputBatch from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, seq_lens: list[int], ) -> list[torch.Tensor]: attn_biases: list[torch.Tensor] = [] for seq_len in seq_lens: bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) inf_mask = torch.empty( (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) return attn_biases def _make_sliding_window_bias( seq_lens: list[int], window_size: Optional[int], dtype: torch.dtype, ) -> list[torch.Tensor]: attn_biases: list[torch.Tensor] = [] for seq_len in seq_lens: tensor = torch.full( (1, seq_len, seq_len), dtype=dtype, fill_value=1, ) shift = 0 mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore if window_size is not None: mask = torch.triu(mask, diagonal=shift - window_size + 1) mask = torch.log(mask) attn_biases.append(mask.to(dtype)) return attn_biases @dataclass class VACCAttentionMetadata(AttentionMetadata): """Metadata for VACCAttentionMetadata. """ # Total number of prefill requests. num_prefills: int # Number of prefill tokens. num_prefill_tokens: int # Number of decode tokens. Note that it is equivalent to the number of # decode requests. num_decode_tokens: int # (num_tokens,). The indices of the token slots that input tokens will be # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor """Metadata for PagedAttention.""" # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. seq_lens_tensor: Optional[torch.Tensor] # Maximum sequence length in the batch. 0 if it is prefill-only batch. max_decode_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks # in the kv cache. Each block can contain up to block_size tokens. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. chunked_prefill: bool seq_lens: Optional[list[int]] = None # For non-chunked prefill # For chunked prefill only max_query_len: Optional[int] = None max_kv_len: Optional[int] = None prefill_query_start_loc: Optional[torch.Tensor] = None kv_start_loc: Optional[torch.Tensor] = None prefill_block_tables: Optional[torch.Tensor] = None # For V1 logits index only query_start_loc: Optional[torch.Tensor] = None # Begin encoder attn & enc/dec cross-attn fields... # Encoder sequence lengths representation encoder_seq_lens: Optional[list[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None # Number of tokens input to encoder num_encoder_tokens: Optional[int] = None # Cross-attention memory-mapping data structures: slot mapping # and block tables cross_slot_mapping: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt # when alibi slopes is used. It is because of the limitation # from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[list[torch.Tensor]] = None self.encoder_attn_bias: Optional[list[torch.Tensor]] = None self.cross_attn_bias: Optional[list[torch.Tensor]] = None @property def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' return ((self.encoder_seq_lens is not None) and (self.encoder_seq_lens_tensor is not None) and (self.max_encoder_seq_len is not None)) @property def is_all_cross_attn_metadata_set(self): ''' All attention metadata required for enc/dec cross-attention is set. Superset of encoder attention required metadata. ''' return (self.is_all_encoder_attn_metadata_set and (self.cross_slot_mapping is not None) and (self.cross_block_tables is not None)) @property def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]: # Currently chunked prefill is not supported if self.num_prefill_tokens == 0: return None return self @property def decode_metadata(self) -> Optional["VACCAttentionMetadata"]: # Currently chunked prefill is not supported if self.num_decode_tokens == 0: return None return self def get_seq_lens( self, attn_type: AttentionType, ): ''' Extract appropriate sequence lengths from attention metadata according to attention type. Arguments: * attn_metadata: Attention metadata structure associated with attention * attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention Returns: * Appropriate sequence lengths tensor for query * Appropriate sequence lengths tensor for key & value ''' if (attn_type == AttentionType.DECODER or attn_type == AttentionType.ENCODER_ONLY): seq_lens_q = self.seq_lens seq_lens_kv = self.seq_lens elif attn_type == AttentionType.ENCODER: seq_lens_q = self.encoder_seq_lens seq_lens_kv = self.encoder_seq_lens elif attn_type == AttentionType.ENCODER_DECODER: seq_lens_q = self.seq_lens seq_lens_kv = self.encoder_seq_lens else: raise AttributeError(f"Invalid attention type {str(attn_type)}") return seq_lens_q, seq_lens_kv def get_attn_bias( self, attn_type: AttentionType, ) -> Optional[list[torch.Tensor]]: ''' Extract appropriate attention bias from attention metadata according to attention type. Arguments: * attn_metadata: Attention metadata structure associated with attention * attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention Returns: * Appropriate attention bias value given the attention type ''' if (attn_type == AttentionType.DECODER or attn_type == AttentionType.ENCODER_ONLY): return self.attn_bias elif attn_type == AttentionType.ENCODER: return self.encoder_attn_bias elif attn_type == AttentionType.ENCODER_DECODER: return self.cross_attn_bias else: raise AttributeError(f"Invalid attention type {str(attn_type)}") def set_attn_bias( self, attn_bias: list[torch.Tensor], attn_type: AttentionType, ) -> None: ''' Update appropriate attention bias field of attention metadata, according to attention type. Arguments: * attn_metadata: Attention metadata structure associated with attention * attn_bias: The desired attention bias value * attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention ''' if (attn_type == AttentionType.DECODER or attn_type == AttentionType.ENCODER_ONLY): self.attn_bias = attn_bias elif attn_type == AttentionType.ENCODER: self.encoder_attn_bias = attn_bias elif attn_type == AttentionType.ENCODER_DECODER: self.cross_attn_bias = attn_bias else: raise AttributeError(f"Invalid attention type {str(attn_type)}") def get_seq_len_block_table_args( self, attn_type: str, ) -> tuple: ''' The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent on the type of attention operation. Decoder attn -> select entirely decoder self-attention-related fields Encoder/decoder cross-attn -> select encoder sequence lengths & cross-attn block-tables fields Encoder attn -> select encoder sequence lengths fields & no block tables Arguments: * attn_metadata: Attention metadata structure associated with attention * is_prompt: True if prefill, False otherwise * attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention Returns: * Appropriate sequence-lengths tensor * Appropriate max sequence-length scalar * Appropriate block tables (or None) ''' if (attn_type == AttentionType.DECODER or attn_type == AttentionType.ENCODER_ONLY): # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run return (self.seq_lens_tensor, self.max_decode_seq_len, self.block_tables) elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, self.cross_block_tables) elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") # class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]): # def __init__(self, input_builder: ModelInputForVACCBuilder) -> None: # self.chunked_prefill = input_builder.chunked_prefill # self.input_builder = input_builder # def prepare(self): # self.input_data = self.input_builder.input_data # def build(self, seq_lens: list[int], query_lens: list[int], # cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata: # input_data = self.input_data # prefill_seq_lens = seq_lens[0:input_data.num_prefills] # prefill_query_lens = query_lens[0:input_data.num_prefills] # slot_mapping = torch.tensor(input_data.slot_mapping, # dtype=torch.int32, # device=self.input_builder.device) # # For chunked-prefill # if self.chunked_prefill and input_data.num_prefill_tokens != 0: # prefill_block_tables = make_tensor_with_pad( # self.input_data.prefill_block_tables, # pad=0, # dtype=torch.int32, # device=self.input_builder.device, # ) # query_lens_tensor = torch.tensor(prefill_query_lens, # dtype=torch.int32, # device=self.input_builder.device) # kv_lens_tensor = torch.tensor(prefill_seq_lens, # dtype=torch.int32, # device=self.input_builder.device) # query_start_loc = torch.zeros(input_data.num_prefills + 1, # dtype=torch.int32, # device=self.input_builder.device) # kv_start_loc = torch.zeros(input_data.num_prefills + 1, # dtype=torch.int32, # device=self.input_builder.device) # torch.cumsum(query_lens_tensor, # dim=0, # dtype=torch.int32, # out=query_start_loc[1:]) # torch.cumsum(kv_lens_tensor, # dim=0, # dtype=torch.int32, # out=kv_start_loc[1:]) # max_query_len = max(prefill_query_lens) # max_kv_len = max(prefill_seq_lens) # else: # prefill_block_tables = None # query_start_loc = None # kv_start_loc = None # max_query_len = None # max_kv_len = None # # For paged attention # if input_data.num_decode_tokens != 0: # seq_lens_tensor = torch.tensor( # input_data.seq_lens[input_data.num_prefills:], # dtype=torch.int32, # device=self.input_builder.device, # ) # block_tables = make_tensor_with_pad( # self.input_data.decode_block_tables, # pad=0, # dtype=torch.int32, # device=self.input_builder.device, # ) # # lowest_dim_size = block_tables.size(-1) # # if lowest_dim_size < 1024: # # padding_amount = 1024 - lowest_dim_size # # padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device) # # block_tables = torch.cat((block_tables, padding), dim=-1) # else: # block_tables = torch.tensor([]) # seq_lens_tensor = torch.tensor( # input_data.seq_lens[:input_data.num_prefills], # dtype=torch.int32, # device=self.input_builder.device, # ) # # For multi-modal models # placeholder_index_maps = None # if len(input_data.multi_modal_inputs_list) != 0: # placeholder_index_maps = { # modality: placeholder_map.index_map() # for modality, placeholder_map in # input_data.multi_modal_placeholder_maps.items() # } # attn_metadata = VACCAttentionMetadata( # chunked_prefill=self.chunked_prefill, # seq_lens=seq_lens, #prefill_seq_lens, # seq_lens_tensor=seq_lens_tensor, # max_query_len=max_query_len, # max_kv_len=max_kv_len, # query_start_loc=query_start_loc, # kv_start_loc=kv_start_loc, # max_decode_seq_len=None, # num_prefills=input_data.num_prefills, # num_prefill_tokens=input_data.num_prefill_tokens, # num_decode_tokens=input_data.num_decode_tokens, # block_tables=block_tables, # prefill_block_tables=prefill_block_tables, # slot_mapping=slot_mapping, # multi_modal_placeholder_index_maps=placeholder_index_maps, # enable_kv_scales_calculation=False, # ) # return attn_metadata def fp32_attention( query_layer, key_layer, value_layer, mask, norm_factor, out_type=None, ): ori_type = out_type if out_type is not None else query_layer.dtype query_layer = query_layer.to(torch.float32) key_layer = key_layer.to(torch.float32) value_layer = value_layer.to(torch.float32) # GQA if query_layer.size(1) != key_layer.size(1): if query_layer.size(1) % key_layer.size(1) != 0: assert False groups = query_layer.size(1) // key_layer.size(1) key_layer = torch.repeat_interleave(key_layer, groups, dim=1) value_layer = torch.repeat_interleave(value_layer, groups, dim=1) matmul_result = torch.bmm( query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] ) * norm_factor mask_output = matmul_result if mask != None: mask = mask if mask.dim() >= 3 else mask.unsqueeze(0) mask_output = matmul_result.masked_fill_(mask, -10000.0) # [b * np, sq, sk] probs = torch.nn.Softmax(dim=-1)(mask_output) context_layer = torch.bmm(probs, value_layer.transpose(0, 1)) return context_layer.transpose(0, 1).to(ori_type) class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]): def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, ) -> None: # if logits_soft_cap is not None: # logger.warning_once("Torch SPDA does not support logits soft cap. " # "Outputs may be slightly off.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes self.sliding_window = sliding_window self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.need_mask = (self.alibi_slopes is not None or self.sliding_window is not None) if kv_cache_dtype != "auto": raise NotImplementedError( "Torch SDPA backend does not support FP8 KV cache. " "Please use xFormers backend instead.") self.attn_type = attn_type def forward( self, layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: VACCAttentionMetadata, # type: ignore output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " "encoder metadata attributes.") elif (attn_type == AttentionType.ENCODER_DECODER and (not attn_metadata.is_all_cross_attn_metadata_set)): raise AttributeError("Encoder/decoder cross-attention " "requires setting cross-attention " "metadata attributes.") # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) else: assert value is None if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. # # Even if there are no new key/value pairs to cache, # we still need to break out key_cache and value_cache # i.e. for later use by paged attention key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) if (key is not None) and (value is not None): if attn_type == AttentionType.ENCODER_DECODER: # Update cross-attention KV cache (prefill-only) # During cross-attention decode, key & value will be None, # preventing this IF-statement branch from running updated_slot_mapping = attn_metadata.cross_slot_mapping else: # Update self-attention KV cache (prefill/decode) updated_slot_mapping = attn_metadata.slot_mapping PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, updated_slot_mapping, self.kv_cache_dtype, layer._k_scale, layer._v_scale) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. # Encoder/decoder cross-attention requires no chunked # prefill (100% prefill or 100% decode tokens, no mix) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens else: # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 if attn_type == AttentionType.DECODER: # Only enforce this shape-constraint for decoder # self-attention assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens output = torch.empty_like(query) if prefill_meta := attn_metadata.prefill_metadata: assert attn_metadata.seq_lens is not None if (kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0): self._run_vacc_forward( output, query, key, value, prefill_meta, attn_type=attn_type) else: # prefix-enabled attention assert not self.need_mask import intel_extension_for_pytorch.llm.modules as ipex_modules output = torch.empty_like(query) ipex_modules.PagedAttention.flash_attn_varlen_func( output[:prefill_meta.num_prefill_tokens, :, :], query[:prefill_meta.num_prefill_tokens, :, :], key_cache, value_cache, prefill_meta.query_start_loc, prefill_meta.kv_start_loc, prefill_meta.max_query_len, prefill_meta.max_kv_len, self.scale, True, prefill_meta.prefill_block_tables, self.alibi_slopes, ) if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have decode metadata.") # Decoding run. # ( # seq_lens_arg, # max_seq_len_arg, # block_tables_arg, # ) = decode_meta.get_seq_len_block_table_args(attn_type) # Note: # decode attention still use SDPA method # reshape k/v_cache to (num_block_grp, block_grp_size, head, hidden_size) k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3]) v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3]) block_per_group = env_blk_grp_size // 16 # convert block_tables to 8K group index block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32) attn_outs = [] for i in range(len(decode_meta.seq_lens_tensor)): seq_len = decode_meta.seq_lens_tensor[i] k_slices = k_cache[block_tables[i], ...] k = \ torch.cat([k_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len] v_slices = v_cache[block_tables[i], ...] v = \ torch.cat([v_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len] q = query[i : i + 1, ...] if q.dtype == torch.bfloat16: attn_out = fp32_attention( q.cpu(), k.cpu(), v.cpu(), None, self.scale ).to(query.dtype).to(query.device) else: attn_out = torch.vacc.scaled_dot_product_attention( query=q, key=k, value=v, attn_mask=None, dropout_p=0, is_causal=False, is_train=False, recompute=False, flash_attention=False, sm_scale=self.scale, ) attn_outs.append(attn_out) output = torch.cat(attn_outs, dim=0) # ''' # PagedAttention.forward_decode( # output[attn_metadata.num_prefill_tokens:, :, :], # query[attn_metadata.num_prefill_tokens:, :, :], # key_cache, # value_cache, # block_tables_arg, # seq_lens_arg, # max_seq_len_arg, # self.kv_cache_dtype, # self.num_kv_heads, # self.scale, # self.alibi_slopes, # layer._k_scale, # layer._v_scale, # ) # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) def _run_vacc_forward( self, output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: VACCAttentionMetadata, attn_type: AttentionType = AttentionType.DECODER, ): # if self.num_kv_heads != self.num_heads: # key = key.repeat_interleave(self.num_queries_per_kv, dim=1) # value = value.repeat_interleave(self.num_queries_per_kv, dim=1) attn_masks = attn_metadata.get_attn_bias(attn_type) if attn_masks is None: if self.alibi_slopes is not None: attn_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: assert attn_metadata.seq_lens is not None attn_masks = _make_sliding_window_bias( attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: seq_lens, _ = attn_metadata.get_seq_lens(attn_type) attn_masks = [None] * len(seq_lens) attn_metadata.set_attn_bias(attn_masks, attn_type) causal_attn = (attn_type == AttentionType.DECODER) seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) start_q, start_kv = 0, 0 for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks): end_q = start_q + seq_len_q end_kv = start_kv + seq_len_kv sub_out=torch.vacc.scaled_dot_product_attention( query[start_q:end_q,:, :].to(torch.float16) * (self.scale), key[start_kv:end_kv,:, :].to(torch.float16), value[start_kv:end_kv,:, :].contiguous().to(torch.float16), attn_mask=None, dropout_p=0.0, is_causal=True if attn_type == AttentionType.DECODER else False, #causal_attn and not self.need_mask, is_train=False, recompute=False, flash_attention=False, sm_scale=1) output[ start_q:end_q,:, :] = sub_out start_q, start_kv = end_q, end_kv return output class VACCAttentionBackend(AttentionBackend): accept_output_buffer: bool = False @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() if head_size not in supported_head_sizes: attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes.") @staticmethod def get_name() -> str: return "TORCH_SDPA_VLLM_V1" @staticmethod def get_impl_cls() -> type["VACCAttentionBackendImpl"]: return VACCAttentionBackendImpl @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: return VACCAttentionMetadata # @staticmethod # def get_state_cls() -> type["CommonAttentionState"]: # return CommonAttentionState @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @staticmethod def get_builder_cls() -> type["VACCAttentionMetadataBuilderV1"]: return VACCAttentionMetadataBuilderV1 @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, cache_dtype_str:str, ) -> tuple[int, ...]: # return PagedAttention.get_kv_cache_shape(num_blocks, block_size, # num_kv_heads, head_size) if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False class VACCAttentionMetadataBuilderV1(AttentionMetadataBuilder[VACCAttentionMetadata]): # def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec, # block_table: BlockTable) -> None: # self.runner = runner # self.block_table = block_table # # For reorder # self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs, # dtype=np.int64) # self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs, # dtype=np.int64) # self.num_prompt_req: int = 0 # self.seq_start_loc_cpu = torch.zeros( # runner.max_num_reqs + 1, # dtype=torch.int32, # device="cpu", # ) # self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device) -> None: super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.scheduler_config = vllm_config.scheduler_config # For reorder self.reorder_prompt_req_index_list = np.empty( vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) self.reorder_decode_req_index_list = np.empty( vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) self.num_prompt_req: int = 0 self.seq_start_loc_cpu = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, device="cpu", ) self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() # def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # vllm_config: VllmConfig, device: torch.device): # super().__init__(kv_cache_spec, layer_names, vllm_config, device) # self.block_size = kv_cache_spec.block_size # model_config = vllm_config.model_config # self.num_heads_q = model_config.get_num_attention_heads( # vllm_config.parallel_config) # self.num_heads_kv = model_config.get_num_kv_heads( # vllm_config.parallel_config) # self.headdim = model_config.get_head_size() def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: prompt_list_idx = 0 decode_list_idx = 0 for req_index in range(input_batch.num_reqs): if input_batch.num_computed_tokens_cpu[ req_index] < input_batch.num_prompt_tokens[req_index]: # prompt stage self.reorder_prompt_req_index_list[prompt_list_idx] = req_index prompt_list_idx += 1 else: # decode stage self.reorder_decode_req_index_list[decode_list_idx] = req_index decode_list_idx += 1 assert decode_list_idx + prompt_list_idx == input_batch.num_reqs # Update prompt requests number self.num_prompt_req = prompt_list_idx reorder_req_num = 0 for req_index in range(decode_list_idx): if self.reorder_decode_req_index_list[req_index] < prompt_list_idx: reorder_req_num += 1 else: break if reorder_req_num == 0: return False reorder_prompt_list = ( self.reorder_prompt_req_index_list[:prompt_list_idx] [-reorder_req_num:]) reorder_decode_list = ( self.reorder_decode_req_index_list[:decode_list_idx] [:reorder_req_num]) assert reorder_decode_list.size == reorder_prompt_list.size for idx in range(reorder_req_num): prompt_req_index = reorder_prompt_list[idx].item() decode_req_index = reorder_decode_list[idx].item() input_batch.swap_states(prompt_req_index, decode_req_index) return True def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens = common_attn_metadata.seq_lens # seq_lens = common_attn_metadata.seq_lens # runner = self.runner # block_table = self.block_table # seq_lens = runner.seq_lens[:num_reqs] num_prompt_req = self.num_prompt_req query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item()) num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() - num_prefill_tokens) # print('query_start_loc_cpu', query_start_loc_cpu) # print('num_prompt_req', num_prompt_req) # print('num_reqs', num_reqs) # num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item() # num_decode_tokens = runner.query_start_loc_np[num_reqs].item( # ) - num_prefill_tokens # block_table.slot_mapping[:num_actual_tokens].copy_( # block_table.slot_mapping_cpu[:num_actual_tokens], # non_blocking=True) # slot_mapping = block_table.slot_mapping[:num_actual_tokens] #.long() # block_table_tensor = block_table.get_device_tensor() slot_mapping = common_attn_metadata.slot_mapping block_table_tensor = common_attn_metadata.block_table_tensor block_num_per_group = env_blk_grp_size // 16 block_table_tensor_new = block_table_tensor[:num_reqs-num_prompt_req, ::block_num_per_group].contiguous() # [bs, seq//16] => [bs, seq//16//block_num_per_group, block_num_per_group] # => [:num_reqs, :, 0] 提取前reqs行,并且把 block_num_per_group 的倍数提取出 attn_metadata = VACCAttentionMetadata( num_prefills=num_prompt_req, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, slot_mapping=slot_mapping, seq_lens_tensor=seq_lens, # decode max_decode_seq_len=None, # decode block_tables=block_table_tensor_new, # decode chunked_prefill=False, # max_query_len=max_query_len, # max_kv_len=max_prefill_seq_len, # prefill_query_start_loc=runner. # query_start_loc_cpu[:num_prompt_req + 1], # prefill # kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + # 1], # prefill prefill_block_tables=block_table_tensor[: num_prompt_req], # prefill query_start_loc=query_start_loc_cpu[:num_reqs + 1], # for logits index # multi_modal_placeholder_index_maps=None, # enable_kv_scales_calculation=False, ) return attn_metadata