# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

if current_platform.is_rocm():
    import flash_mla_cuda
    _flashmla_C_AVAILABLE = True


def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False, "FlashMLA is only supported on CUDA and ROCm devices."
    if current_platform.get_device_capability()[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        return False, "vllm._flashmla_C is not available, likely was not "\
            "compiled due to insufficient nvcc version or a supported arch "\
            "(only sm90a currently) was not in the list of target arches to "\
            "compile for."
    return True, None


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: The number of k heads.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    if current_platform.is_rocm():
        return flash_mla_cuda.get_mla_metadata(cache_seqlens,
                                               num_heads_per_head_k,
                                               num_heads_k)
    else:
        return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
                                                      num_heads_per_head_k,
                                                      num_heads_k)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: float. The scale applied to QK^T before softmax.
            Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply a causal attention mask.
        k_scale: Scale factor for the fp8 KV cache; only used on ROCm when
            kv_cache_dtype == "fp8".
        kv_cache_dtype: Dtype of the KV cache ("auto" or "fp8").

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
""" if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) if current_platform.is_rocm(): if kv_cache_dtype == "fp8": out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale, causal, tile_scheduler_metadata, num_splits, k_scale, "fp8_e4m3", ) return out, softmax_lse out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale, causal, tile_scheduler_metadata, num_splits, ) else: out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale, causal, tile_scheduler_metadata, num_splits, ) return out, softmax_lse # # TODO: Add fake functions # # @register_fake("_flashmla_C::get_mla_metadata") # def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: # return .... # # @register_fake("_flashmla_C::fwd_kvcache_mla") # def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: # return .... #