from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

from vllm.multimodal import MultiModalPlaceholderMap

try:
    from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeMlaWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionState, AttentionType)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonMetadata
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
# VACC replacement for vllm.attention.ops.paged_attn.PagedAttention.
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import \
    VaccPagedAttention as PagedAttention
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
    from vllm_vacc.vllm.worker.vacc_model_runner import (
        ModelInputForVACCBuilder, ModelInputForVACCWithSamplingMetadata)


class VACCMLABackend(AttentionBackend):
    """Attention backend entry point wiring the VACC MLA classes together."""

    @staticmethod
    def get_name() -> str:
        return "TORCH_VACC"

    @staticmethod
    def get_impl_cls() -> Type["VACCMLAImpl"]:
        return VACCMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return VACCMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["VACCMLAMetadataBuilder"]:
        return VACCMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["VACCMLAState"]:
        return VACCMLAState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the per-layer KV cache shape.

        The cache holds a single vector of ``head_size`` per token slot (no
        KV-head axis); ``_forward_decode`` later splits it into the latent
        part (``[:kv_lora_rank]``) and the remainder.
        """
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        # Presumably 512 (kv_lora_rank) + 64 (rope dim); compare the 512
        # split in vacc_paged_attention_naive_singleop — verify per model.
        return [576]


class VACCMLAState(AttentionState):
    """Holds the static input buffers used while capturing decode graphs."""

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate capture buffers sized for ``max_batch_size``.

        The buffers are released when the ``with`` block exits, including
        when it exits with an exception.
        """
        self._is_graph_capturing = True
        device = self.runner.device
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=device)
        self._positions = torch.zeros((max_batch_size, ),
                                      dtype=torch.long,
                                      device=device)
        try:
            yield
        finally:
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables
            del self._positions

    def graph_clone(self, batch_size: int):
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata backed by the first ``batch_size``
        entries of the capture buffers.

        Raises:
            NotImplementedError: for encoder/decoder models.
        """
        assert self._is_graph_capturing
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "VACCMLAState does not support encoder/decoder yet")

        return self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
            input_positions=self._positions[:batch_size],
            head_dim=self.runner.model_config.get_head_size())

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the tensors a captured graph reads as inputs, by name."""
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "VACCMLAState does not support encoder/decoder yet")
        decode_meta = attn_metadata.decode_metadata
        return {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": decode_meta.seq_lens_tensor,
            "block_tables": decode_meta.block_tables,
            "input_positions": decode_meta.input_positions,
        }

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """Copy this step's decode inputs into the captured graph buffers.

        The encoder/decoder check runs first so nothing is copied before
        raising.
        """
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "VACCMLAState does not support encoder/decoder yet")
        input_positions = attn_metadata.input_positions
        num_positions = input_positions.shape[0]
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # The graph buffer is padded, so only copy the live prefix.
        input_buffers["input_positions"][:num_positions].copy_(
            input_positions, non_blocking=True)

    def begin_forward(self, model_input):
        return


@dataclass
class VACCMLAMetadata(MLACommonMetadata):
    """Attention metadata for the VACC MLA backend.

    NOTE: Any python object stored here is not updated when it is cuda-graph
    replayed. If you have values that need to be changed dynamically, it
    should be stored in tensor. The tensor has to be updated from
    `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None

    input_positions: Optional[torch.Tensor] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    _cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
    _cached_decode_metadata: Optional["VACCMLAMetadata"] = None

    # Redeclared without a default after defaulted fields; this only works
    # because a base class already declares it (dataclasses keep the
    # inherited position) — presumably AttentionMetadata, verify.
    num_prefill_tokens: int

    num_kv_splits: int = 4  # TODO(lucas) add heuristic
    attn_logits: Optional[torch.Tensor] = None
    req_idx: Optional[torch.Tensor] = None

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    def __post_init__(self):
        supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
        if (self.head_dim is not None
                and self.head_dim not in supported_head_sizes):
            # Adjacent literals are concatenated into ONE message string
            # (previously a comma made the exception args a 2-tuple).
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    @property
    def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
        """Metadata restricted to the leading prefill requests (cached)."""
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None

        n_seqs = self.num_prefills
        n_tokens = self.num_prefill_tokens

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:n_seqs + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:n_tokens])
        seq_lens = (None if self.seq_lens is None else self.seq_lens[:n_seqs])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:n_seqs])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:n_seqs + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:n_seqs])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:n_seqs])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[:n_tokens])

        self._cached_prefill_metadata = VACCMLAMetadata(
            num_prefills=n_seqs,
            num_prefill_tokens=n_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            input_positions=input_positions,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            head_dim=self.head_dim)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
        """Metadata restricted to the trailing decode requests (cached)."""
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[self.num_prefill_tokens:])

        # The decode kernel indexes seq_lens[i] alongside the decode-only
        # query rows and block_tables, so prefill entries must be dropped in
        # a mixed batch. With no prefills the SAME list object is kept, so
        # advance_step's in-place ``seq_lens[i] += 1`` stays visible here.
        if self.seq_lens is None or self.num_prefills == 0:
            decode_seq_lens = self.seq_lens
        else:
            decode_seq_lens = self.seq_lens[self.num_prefills:]

        # Batch may be composed of prefill|decodes, adjust query start
        # indices to refer to the start of decodes. E.g.
        # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        if self.query_start_loc is not None:
            query_start_loc = (self.query_start_loc[self.num_prefills:] -
                               self.query_start_loc[self.num_prefills])
        else:
            query_start_loc = None
        seq_start_loc = (self.seq_start_loc[self.num_prefills:]
                         if self.seq_start_loc is not None else None)

        self._cached_decode_metadata = VACCMLAMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=decode_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            input_positions=input_positions,
            head_dim=self.head_dim)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForVACCWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """Update metadata in-place to advance one decode step.

        Increments the host-side ``seq_lens`` for the first ``num_queries``
        entries, then delegates the device-side update (tokens, positions,
        ``seq_lens_tensor``, ``slot_mapping``) to
        ``ops.advance_step_flashattn``.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.max_query_len = 1
            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            # NOTE(review): VACCMLAMetadataBuilder.build passes
            # max_decode_seq_len=None, so this assert fails on that path and
            # max_decode_seq_len is never refreshed below — confirm whether
            # multi-step decoding is expected to work here.
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class VACCMLAMetadataBuilder(AttentionMetadataBuilder[VACCMLAMetadata]):
    """Builds VACCMLAMetadata from the VACC model runner's prepared inputs."""

    def __init__(self, input_builder: "ModelInputForVACCBuilder"):
        self.chunked_prefill = getattr(input_builder, "chunked_prefill", True)
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset the per-batch accumulators."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.input_positions: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Most per-token inputs are taken directly from
        ``input_builder.input_data`` rather than from the accumulators reset
        in ``prepare``. Block tables come from ``prefill_block_tables`` when
        the batch has any prefill tokens, else from ``decode_block_tables``.
        ``max_prefill_seq_len`` / ``max_decode_seq_len`` are not computed and
        are passed as None.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        self.input_data = self.input_builder.input_data
        # NOTE(review): these alias input_data's containers, so the graph
        # padding below mutates input_data.slot_mapping in place.
        self.slot_mapping = self.input_data.slot_mapping
        self.context_lens = self.input_data.context_lens
        if self.input_data.num_prefill_tokens != 0:
            self.block_tables = self.input_data.prefill_block_tables
        else:
            self.block_tables = self.input_data.decode_block_tables
        self.input_positions = self.input_data.input_positions
        self.prefill_seq_lens = seq_lens[0:self.input_data.num_prefills]
        self.num_prefills = self.input_data.num_prefills
        self.num_prefill_tokens = self.input_data.num_prefill_tokens
        self.num_decode_tokens = self.input_data.num_decode_tokens

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # NOTE(review): ``[] * n`` is ``[]``, so this extend is a no-op.
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens
            # NOTE(review): _get_graph_runner_block_tables is not defined on
            # this class in this file — confirm it is provided elsewhere.
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )

        assert device is not None
        pin_memory = self.runner.pin_memory
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           pin_memory)
        input_positions = async_tensor_h2d(self.input_positions, torch.int,
                                           device, pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int,
                                               device, pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc,
                                                  torch.int32, device,
                                                  pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return VACCMLAMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            input_positions=input_positions,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None,
            max_decode_seq_len=None,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            num_kv_splits=4,  # TODO(lucas) add heuristic
            head_dim=self.runner.model_config.get_head_size(),
        )


class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
    """MLA attention implementation using torch.vacc SDPA kernels."""

    # Attribute names collected by extract_weights, in output order.
    _EXTRACTABLE_WEIGHTS = (
        "W_Q", "W_Q_scales", "W_QR", "W_QR_scales", "W_Q_QR",
        "W_Q_QR_scales", "W_UK", "W_UK_scales", "W_Q_UK_scales", "W_UV",
        "W_UV_scales", "W_UV_O", "W_UV_O_scales")

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **kwargs) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **kwargs)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "VACCMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "VACCMLAImpl")

    def extract_weights(self):
        """Return a dict of whichever projection weights/scales are set."""
        return {
            name: getattr(self, name)
            for name in self._EXTRACTABLE_WEIGHTS if hasattr(self, name)
        }

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: VACCMLAMetadata,
    ) -> torch.Tensor:
        """Causal attention over the prefill tokens, followed by o_proj.

        ``kv_c_and_k_pe_cache`` is unused: each prefill sequence attends
        only to its own tokens, sliced from ``q``/``k``/``v`` by
        ``prefill_metadata.seq_lens``.
        """
        assert isinstance(attn_metadata, VACCMLAMetadata)

        kv_nope = self.kv_b_proj(kv_c_normed)[0]\
            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
        v = v.contiguous()

        # The VACC kernel accepts a v head dim different from the qk head
        # dim, so v is not zero-padded.
        seq_lens = attn_metadata.prefill_metadata.seq_lens
        if len(seq_lens) == 1:
            attn_out = _vacc_sdpa(q, k, v, self.scale, is_causal=True)
        else:
            bounds = list(accumulate(seq_lens, initial=0))
            attn_out = torch.cat([
                _vacc_sdpa(q[start:end, :],
                           k[start:end, :],
                           v[start:end, :],
                           self.scale,
                           is_causal=True)
                for start, end in zip(bounds, bounds[1:])
            ], dim=0)

        attn_out = attn_out.view(-1, self.num_heads * v.shape[-1])
        return self.o_proj(attn_out)[0]

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: VACCMLAMetadata,
    ) -> torch.Tensor:
        """Decode attention against the paged latent cache, then v-up/o proj.

        Raises:
            NotImplementedError: for fp8 KV cache dtypes.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)

        # (num_blocks, block_size, head_size) -> add a head dim of 1.
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        # Values are the first kv_lora_rank channels of the same cache.
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]

        o = vacc_paged_attention_naive(q,
                                       kv_c_and_k_pe_cache,
                                       kv_c_cache,
                                       block_table=decode_meta.block_tables,
                                       seq_lens=decode_meta.seq_lens,
                                       sm_scale=self.scale)

        return self._v_up_proj_and_o_proj(o)


def _vacc_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
               sm_scale, is_causal: bool) -> torch.Tensor:
    """Call torch.vacc SDPA with the inference-only flags used in this file."""
    return torch.vacc.scaled_dot_product_attention(query=query,
                                                   key=key,
                                                   value=value,
                                                   attn_mask=None,
                                                   dropout_p=0,
                                                   is_causal=is_causal,
                                                   is_train=False,
                                                   recompute=False,
                                                   flash_attention=False,
                                                   sm_scale=sm_scale)


def vacc_paged_attention_naive(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: List[int],
    out: Optional[torch.Tensor] = None,
    sm_scale=-1,
) -> torch.Tensor:
    """Per-sequence paged decode attention built on torch.vacc SDPA.

    Args:
        query: decode queries; row ``i`` belongs to sequence ``i``.
        key_cache / value_cache: paged caches indexed as
            ``[block, slot, head, dim]``.
        block_table: physical block ids per sequence (row ``i``).
        seq_lens: number of cached tokens to attend to, per sequence.
        out: accepted for interface compatibility; it is NOT written to.
        sm_scale: softmax scale forwarded to the kernel.

    Returns:
        The attention output, concatenated over sequences along dim 0.
    """
    k_heads, k_dim = key_cache.shape[2], key_cache.shape[3]
    v_heads, v_dim = value_cache.shape[2], value_cache.shape[3]

    if len(seq_lens) == 1:
        # Batch-1 fast path: reads the leading seq_lens[0] slots of the whole
        # cache WITHOUT consulting block_table.
        # NOTE(review): only correct if the lone sequence occupies blocks
        # 0, 1, 2, ... in order — confirm against the VACC block allocator.
        k = key_cache.view(-1, k_heads, k_dim)[:seq_lens[0]]
        v = value_cache.view(-1, v_heads, v_dim)[:seq_lens[0]]
        return _vacc_sdpa(query, k, v, sm_scale, is_causal=False)

    attn_outs = []
    for i, seq_len in enumerate(seq_lens):
        blocks = block_table[i]
        # Gather this sequence's blocks (block-major, slot-minor) and flatten
        # them into one token axis; reshape copies only if a view is illegal.
        k = key_cache[blocks].reshape(-1, k_heads, k_dim)[:seq_len]
        v = value_cache[blocks].reshape(-1, v_heads, v_dim)[:seq_len]
        attn_outs.append(
            _vacc_sdpa(query[i:i + 1, :, :], k, v, sm_scale, is_causal=False))
    return torch.cat(attn_outs, dim=0)


# MLA single op impl
def vacc_paged_attention_naive_singleop(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    seq_lens,
    block_table=None,
    out: Optional[torch.Tensor] = None,
    sm_scale=-1,
    kv_lora_rank: int = 512,
) -> torch.Tensor:
    """Reference MLA decode using einsum over the leading cache slots.

    Splits the last query/key dim at ``kv_lora_rank``: the first part is
    matched against the latent values, the rest against the rope keys.
    Reads the first ``seq_lens`` slots of the flattened cache; ``block_table``
    and ``out`` are accepted for interface compatibility but unused.
    """
    k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens]
    v = value_cache.view(-1, value_cache.shape[2],
                         value_cache.shape[3])[:seq_lens].squeeze(1)
    pe_cache = k[..., kv_lora_rank:].squeeze(1)

    q_nope_kv_c = torch.einsum("shc,tc->sht", query[..., :kv_lora_rank], v)
    q_pe_k_pe = torch.einsum("shr,tr->sht", query[..., kv_lora_rank:],
                             pe_cache)
    scores = (q_nope_kv_c + q_pe_k_pe) * sm_scale
    scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(query)
    return torch.einsum("sht,tc->shc", scores, v)