# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
"""Python wrappers around the FlashMLA CUDA extensions.

The compiled extensions ``vllm._flashmla_C`` and ``vllm._flashmla_extension_C``
are optional: they are only imported on CUDA platforms, and a failed import
just marks them as unavailable. When they are unavailable, the kernels that
would be re-exported from ``vllm.third_party.flashmla`` are replaced by stubs
that raise ``RuntimeError`` with a human-readable reason.
"""

from typing import NoReturn

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import _custom_ops as ops

logger = init_logger(__name__)

# Probe both extensions under a single CUDA guard. The imports are not used by
# name (hence ``noqa: F401``); the ``torch.ops._flashmla_extension_C`` calls
# further down rely on the extension having been loaded here. Each probe is
# independent: a missing ``_flashmla_C`` does not skip the second import.
_flashmla_C_AVAILABLE = False
_flashmla_extension_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False

    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False


def _is_flashmla_available() -> tuple[bool, str | None]:
    """Report whether both FlashMLA extensions were imported successfully.

    Returns:
        ``(True, None)`` when both extensions are loaded; otherwise
        ``(False, reason)`` describing the first missing extension
        (``_flashmla_C`` is checked before ``_flashmla_extension_C``).
    """
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
    if not _flashmla_extension_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )
    return True, None


def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """Check whether dense FlashMLA can run on the current device.

    Requires both FlashMLA extensions and a device in the capability
    family 90 (Hopper).

    Returns:
        ``(is_supported_flag, unsupported_reason)``; the reason is ``None``
        when supported.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    if not current_platform.is_device_capability_family(90):
        return False, "FlashMLA Dense is only supported on Hopper devices."
    return True, None


def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """Check whether sparse FlashMLA can run on the current device.

    Requires both FlashMLA extensions and a device in capability family 90
    (Hopper) or 100 (Blackwell).

    Returns:
        ``(is_supported_flag, unsupported_reason)``; the reason is ``None``
        when supported.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    if not (
        current_platform.is_device_capability_family(90)
        or current_platform.is_device_capability_family(100)
    ):
        return (
            False,
            "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
        )
    return True, None


def _raise_flashmla_unavailable(*_args, **_kwargs) -> NoReturn:
    """Raise ``RuntimeError`` explaining why FlashMLA cannot be used.

    Accepts and ignores arbitrary arguments so it can stand in for any kernel
    entry point when the extensions are missing.

    Raises:
        RuntimeError: Always.
    """
    _, reason = _is_flashmla_available()
    raise RuntimeError(reason or "FlashMLA is not available")


if _is_flashmla_available()[0]:
    from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
        FlashMLASchedMeta,
        flash_attn_varlen_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_mla_sparse_fwd,
        flash_mla_with_kvcache,
        get_mla_metadata,
    )
else:
    # Keep every public name importable so callers can import this module on
    # any platform; calling a stub raises with the unavailability reason.

    class FlashMLASchedMeta:  # type: ignore[no-redef]
        """Placeholder used when the FlashMLA extensions are unavailable."""

    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]


def get_mla_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute decoding metadata for the dense FP8 FlashMLA kernel.

    Thin wrapper over
    ``torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8``.

    Args:
        cache_seqlens: KV-cache sequence lengths (presumably one entry per
            request — TODO confirm against the kernel).
        num_q_tokens_per_head_k: Forwarded unchanged to the kernel.
        num_heads_k: Forwarded unchanged to the kernel.

    Returns:
        The pair of tensors produced by the kernel. NOTE(review): these look
        like the ``tile_scheduler_metadata`` / ``num_splits`` arguments of
        ``flash_mla_with_kvcache_fp8`` — confirm.

    Raises:
        RuntimeError: If the FlashMLA extensions are unavailable.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()
    return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens,
        num_q_tokens_per_head_k,
        num_heads_k,
    )


def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the FP8 FlashMLA decode kernel against a paged KV cache.

    Thin wrapper over ``torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8``.

    Args:
        q: Query tensor; its last dimension sets the default softmax scale.
        k_cache: KV cache tensor (layout defined by the kernel — TODO
            document).
        block_table: Block table for the paged cache.
        cache_seqlens: KV-cache sequence lengths.
        head_dim_v: Value head dimension.
        tile_scheduler_metadata: Scheduling metadata (presumably from
            ``get_mla_metadata_dense_fp8`` — confirm).
        num_splits: Split counts (presumably from
            ``get_mla_metadata_dense_fp8`` — confirm).
        softmax_scale: Attention scale; defaults to ``q.shape[-1] ** -0.5``.
        causal: Forwarded unchanged to the kernel.
        descale_q: Optional tensor forwarded to the kernel (presumably the
            FP8 dequantization scale for ``q`` — confirm).
        descale_k: Optional tensor forwarded to the kernel (presumably the
            FP8 dequantization scale for ``k_cache`` — confirm).

    Returns:
        ``(out, softmax_lse)`` as produced by the kernel.

    Raises:
        RuntimeError: If the FlashMLA extensions are unavailable.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # The kernel's positional order differs from this wrapper's signature
    # (head_dim_v comes before cache_seqlens/block_table), so keep it as-is.
    out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
        q,
        k_cache,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    return out, softmax_lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Dispatched through ``vllm._custom_ops.sparse_prefill_fwd``. Unlike the FP8
    wrappers above, no FlashMLA availability check is performed here.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1
      or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        For the definitions of output, max_logits and lse, please refer to
        README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits: [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    results = ops.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
    return results


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#