156 lines
5.1 KiB
Python
156 lines
5.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
|
|
logger = init_logger(__name__)

# Probe for the compiled FlashMLA kernels for the current platform.
# `_flashmla_C_AVAILABLE` records whether they could be imported; it is
# consulted by `is_flashmla_supported()` so a missing extension is reported
# there instead of breaking the import of this module.
_flashmla_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False

if current_platform.is_rocm():
    # On ROCm the kernels come from the separate `flash_mla_cuda` module.
    # Guard the import: previously an ImportError here made this whole
    # module unimportable on ROCm hosts without the extension.
    try:
        import flash_mla_cuda
        _flashmla_C_AVAILABLE = True
    except ImportError:
        logger.debug("flash_mla_cuda could not be imported; "
                     "FlashMLA is disabled on this ROCm host.")
        # Keep the name bound so references to it fail with a clear
        # AttributeError on None rather than a NameError.
        flash_mla_cuda = None
        _flashmla_C_AVAILABLE = False
|
|
|
|
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Report whether the FlashMLA kernels can be used on the current device.

    Return:
        is_supported_flag, unsupported_reason (optional). The reason is
        None exactly when the flag is True.
    """
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False, "FlashMLA is supported on CUDA and ROCM devices."

    on_rocm = current_platform.is_rocm()

    capability = current_platform.get_device_capability()
    # Defensive: treat a missing (None) capability as unsupported instead of
    # raising TypeError when indexing it.
    if capability is None or capability[0] != 9:
        if on_rocm:
            return False, ("FlashMLA on ROCm requires a device whose "
                           "capability major version is 9.")
        return False, "FlashMLA is only supported on Hopper devices."

    if not _flashmla_C_AVAILABLE:
        if on_rocm:
            # The CUDA-specific explanation below (nvcc, sm90a) does not
            # apply to the ROCm `flash_mla_cuda` extension.
            return False, ("flash_mla_cuda is not available; the ROCm "
                           "FlashMLA extension could not be imported.")
        return False, "vllm._flashmla_C is not available, likely was not "\
            "compiled due to insufficient nvcc version or a supported arch "\
            "(only sm90a currently) was not in the list of target arches to "\
            "compile for."

    return True, None
|
|
|
|
|
|
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the scheduling metadata that flash_mla_with_kvcache consumes.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    # Pick the kernel namespace for this platform: the standalone
    # flash_mla_cuda extension on ROCm, the registered torch ops otherwise.
    kernel_ns = (flash_mla_cuda
                 if current_platform.is_rocm() else torch.ops._flashmla_C)
    return kernel_ns.get_mla_metadata(cache_seqlens, num_heads_per_head_k,
                                      num_heads_k)
|
|
|
|
|
|
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run FlashMLA decode attention against a paged KV cache.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
                       Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        k_scale: Scale for an FP8 k_cache, forwarded as-is to the ROCm FP8
                 kernel (presumably the dequantization scale — confirm
                 against flash_mla_cuda). Unused on every other path.
        kv_cache_dtype: KV-cache dtype name. On ROCm, "fp8" or "fp8_e4m3"
                        selects the FP8 (e4m3) kernel; any other value uses
                        the unquantized kernel. Not consulted on CUDA.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)

    # Positional arguments shared by every kernel variant, in kernel order.
    # The third slot is always passed as None by this wrapper.
    kernel_args = (
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )

    if not current_platform.is_rocm():
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(*kernel_args)
    elif kv_cache_dtype in ("fp8", "fp8_e4m3"):
        # "fp8_e4m3" is the explicit spelling of the format the quantized
        # kernel is invoked with; previously it fell through to the
        # unquantized kernel and read the FP8 cache as if unquantized.
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
            *kernel_args, k_scale, "fp8_e4m3")
    else:
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(*kernel_args)
    return out, softmax_lse
|
|
|
|
|
|
#
|
|
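# TODO: Register fake (meta) implementations with torch.library.register_fake
# so these custom ops can be traced with FakeTensors (e.g. by torch.compile).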
|
|
#
|
|
# @register_fake("_flashmla_C::get_mla_metadata")
|
|
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
# return ....
|
|
#
|
|
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
|
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
# return ....
|
|
# |