# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MLA attention backend whose decode path calls FlashInfer's TRT-LLM kernel."""

from typing import Optional, Union

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonImpl,
    MLACommonMetadata,
)

logger = init_logger(__name__)

# Number of uint8 elements (i.e. bytes) in the shared workspace: 128 MiB.
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class FlashInferMLABackend(MLACommonBackend):
    """Backend descriptor that names this backend and exposes its impl class."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHINFER_MLA"

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashInferMLAImpl


# Zero-initialised scratch buffer created on the CUDA device when this module
# is imported. Every FlashInferMLAImpl instance stores a reference to this
# same tensor rather than allocating its own.
g_fi_workspace = torch.zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=torch.uint8,
    device="cuda",
)


class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation that decodes with
    ``trtllm_batch_decode_with_kv_cache_mla``.

    Only decoder attention is accepted, and ALiBi slopes, sliding windows and
    logit soft-capping are rejected at construction time.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state, then validate the configuration.

        All arguments are forwarded unchanged to ``MLACommonImpl.__init__``.

        Raises:
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        # Truthiness test: None, 0, 0.0 and an empty list all count as "unset".
        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferMLAImpl")

        self._workspace_buffer = g_fi_workspace

        # Filled in by the first _forward_decode call and reused afterwards.
        self.bmm1_scale: Optional[float] = None
        self.bmm2_scale: Optional[float] = None

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the decode attention kernel over the paged MLA KV cache.

        Args:
            q: Either one query tensor, or a ``(q_nope, q_pe)`` pair that is
                concatenated along the last dimension before use.
            kv_c_and_k_pe_cache: KV cache tensor; must contain at least one
                element.
            attn_metadata: Metadata whose ``decode`` field must be set; it
                supplies the block table and sequence lengths.
            layer: Attention layer providing the ``_q_scale_float``,
                ``_k_scale_float`` and ``_v_scale_float`` values.

        Returns:
            ``(output, None)``. The second slot is reserved for the LSE, which
            is not returned yet.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        if isinstance(q, tuple):
            nope_part, rope_part = q
            q = torch.cat((nope_part, rope_part), dim=-1)

        # trtllm API requires extra dimension q_len_per_request for MTP
        q = q.unsqueeze(1)

        # The scales are read from the layer only once. Later calls reuse the
        # cached values even if the layer's scale attributes change.
        if self.bmm1_scale is None:
            qk_scale = layer._q_scale_float * layer._k_scale_float
            self.bmm1_scale = qk_scale * self.scale
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        # NOTE(review): the inserted axis 1 is presumably what the kernel
        # expects for its cache layout - confirm against the FlashInfer docs.
        kv_cache = kv_c_and_k_pe_cache.unsqueeze(1)

        attn_out = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=decode_meta.block_table,
            seq_lens=decode_meta.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
        )

        # TODO: Return LSE pending support from Flashinfer API:
        # https://github.com/flashinfer-ai/flashinfer/pull/1566
        return attn_out, None