Files

139 lines
4.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)

# Upstream vLLM probes for the compiled `vllm._flashmla_C` extension, and only
# when running on CUDA. This port instead relies on the standalone `flash_mla`
# package, so availability is decided purely by whether that package imports.
try:
    import flash_mla
except ImportError as e:
    logger.warning("Failed to import from flash_mla with %r on MACA Platform", e)
    _flashmla_AVAILABLE = False
else:
    _flashmla_AVAILABLE = True
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Check whether the FlashMLA kernels can be used.

    Return:
        A ``(is_supported_flag, unsupported_reason)`` pair. The reason is
        ``None`` when supported; otherwise it is a short explanation.
    """
    # The upstream CUDA / Hopper / `vllm._flashmla_C` checks do not apply to
    # this port. The only requirement is that `flash_mla` imported successfully.
    if _flashmla_AVAILABLE:
        return True, None
    return False, "flash_mla is not available"
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the tile-scheduler metadata consumed by flash_mla_with_kvcache.

    Thin wrapper around ``flash_mla.flash_mla_interface.get_mla_metadata``.
    It replaces the upstream ``torch.ops._flashmla_C.get_mla_metadata`` call.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.

    Raises:
        RuntimeError: if the ``flash_mla`` package failed to import.
    """
    # Without this guard, a failed import surfaces as a misleading
    # ``NameError: name 'flash_mla' is not defined``.
    if not _flashmla_AVAILABLE:
        raise RuntimeError(
            "flash_mla is not available; check is_flashmla_supported() "
            "before calling get_mla_metadata")
    return flash_mla.flash_mla_interface.get_mla_metadata(
        cache_seqlens, num_heads_per_head_k, num_heads_k)
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run FlashMLA attention of ``q`` against the paged KV cache.

    Thin wrapper around ``flash_mla.flash_mla_interface.flash_mla_with_kvcache``.
    It replaces the upstream ``torch.ops._flashmla_C.fwd_kvcache_mla`` call.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        RuntimeError: if the ``flash_mla`` package failed to import.
    """
    # Without this guard, a failed import surfaces as a misleading
    # ``NameError: name 'flash_mla' is not defined``.
    if not _flashmla_AVAILABLE:
        raise RuntimeError(
            "flash_mla is not available; check is_flashmla_supported() "
            "before calling flash_mla_with_kvcache")
    # The upstream version resolved softmax_scale=None to head_dim**-0.5 before
    # calling the op. Here None is passed through unchanged, presumably so
    # flash_mla applies its own default (TODO confirm for this flash_mla build).
    # The arguments are positional in the same order the original wrapper used.
    out, softmax_lse = flash_mla.flash_mla_interface.flash_mla_with_kvcache(
        q,
        k_cache,
        block_table,
        cache_seqlens,
        head_dim_v,
        tile_scheduler_metadata,
        num_splits,
        softmax_scale,
        causal,
    )
    return out, softmax_lse
#
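# TODO: Add fake (meta) implementations via `register_fake` so these ops can be
# traced with FakeTensors (e.g. under torch.compile) without running kernels.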
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#