# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# if current_platform.is_cuda():
#     try:
#         import vllm._flashmla_C  # noqa: F401
#         _flashmla_C_AVAILABLE = True
#     except ImportError:
#         _flashmla_C_AVAILABLE = False
# else:
#     _flashmla_C_AVAILABLE = False

# The kernels are taken from the external ``flash_mla`` package; the in-tree
# ``vllm._flashmla_C`` path is kept above (commented out) for reference only.
# A failed import is recorded in ``_flashmla_AVAILABLE`` instead of being
# raised, so this module can always be imported.
try:
    import flash_mla
    _flashmla_AVAILABLE = True
except ImportError as e:
    logger.warning(
        "Failed to import from flash_mla with %r on MACA Platform", e)
    _flashmla_AVAILABLE = False


def _ensure_flashmla_available() -> None:
    """Raise a descriptive ImportError if ``flash_mla`` failed to import.

    Without this guard the wrappers below would dereference the unbound
    module name ``flash_mla`` and fail with an opaque ``NameError``.
    """
    if not _flashmla_AVAILABLE:
        raise ImportError(
            "flash_mla is not available; call is_flashmla_supported() "
            "before using the FlashMLA wrappers.")


def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Report whether the FlashMLA backend can be used.

    Only the success of ``import flash_mla`` is checked; the device and
    architecture checks of the CUDA path are disabled (kept below as
    comments).

    Return:
        is_supported_flag, unsupported_reason (optional).
    """
    # if not current_platform.is_cuda():
    #     return False, "FlashMLA is only supported on CUDA devices."
    # if current_platform.get_device_capability()[0] != 9:
    #     return False, "FlashMLA is only supported on Hopper devices."
    # if not _flashmla_C_AVAILABLE:
    #     return False, "vllm._flashmla_C is not available, likely was not "\
    #         "compiled due to insufficient nvcc version or a supported arch "\
    #         "(only sm90a currently) was not in the list of target arches to "\
    #         "compile for."
    if not _flashmla_AVAILABLE:
        return False, "flash_mla is not available"
    return True, None


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward to ``flash_mla.flash_mla_interface.get_mla_metadata``.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.

    Raises:
        ImportError: if the ``flash_mla`` package could not be imported.
    """
    # return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
    #                                               num_heads_per_head_k,
    #                                               num_heads_k)
    _ensure_flashmla_available()
    return flash_mla.flash_mla_interface.get_mla_metadata(
        cache_seqlens, num_heads_per_head_k, num_heads_k)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward to ``flash_mla.flash_mla_interface.flash_mla_with_kvcache``.

    All arguments are passed through positionally and unchanged.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            None is forwarded as-is; the backend is presumably responsible
            for defaulting to 1 / sqrt(head_dim) -- this wrapper no longer
            computes it (see the disabled code below).
        causal: bool. Whether to apply causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        ImportError: if the ``flash_mla`` package could not be imported.
    """
    # if softmax_scale is None:
    #     softmax_scale = q.shape[-1]**(-0.5)
    # out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
    #     q,
    #     k_cache,
    #     None,
    #     head_dim_v,
    #     cache_seqlens,
    #     block_table,
    #     softmax_scale,
    #     causal,
    #     tile_scheduler_metadata,
    #     num_splits,
    # )
    _ensure_flashmla_available()
    out, softmax_lse = flash_mla.flash_mla_interface.flash_mla_with_kvcache(
        q,
        k_cache,
        block_table,
        cache_seqlens,
        head_dim_v,
        tile_scheduler_metadata,
        num_splits,
        softmax_scale,
        causal,
    )
    return out, softmax_lse


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#