115 lines
4.0 KiB
Python
115 lines
4.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
|
|
|
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
|
MLACommonImpl,
|
|
MLACommonMetadata)
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)
|
|
|
|
# Number of elements in the workspace tensor handed to the FlashInfer MLA
# decode kernel. The tensor is allocated as uint8, so this is 128 MiB.
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024**2
|
|
|
|
|
|
class FlashInferMLABackend(MLACommonBackend):
    """Backend descriptor that pairs the common MLA backend with the
    FlashInfer-based implementation class defined in this module.
    """

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashInferMLAImpl

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHINFER_MLA"
|
|
|
|
|
|
# Zero-filled uint8 scratch tensor of FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE
# elements on the CUDA device. It is created once, when this module is
# imported, and every FlashInferMLAImpl instance keeps a reference to this
# same tensor as its workspace buffer.
# NOTE(review): allocating at import time means importing this module
# presumably requires an available CUDA device -- confirm this is intended.
g_fi_workspace = torch.zeros(FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
                             dtype=torch.uint8,
                             device="cuda")
|
|
|
|
|
|
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA attention implementation whose decode step is delegated to
    FlashInfer's ``trtllm_batch_decode_with_kv_cache_mla`` kernel.

    Only decoder attention is accepted; ALiBi slopes, sliding windows and
    logits soft-capping are rejected at construction time.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state and validate the configuration.

        All arguments are forwarded unchanged (positionally, followed by
        ``**mla_args``) to ``MLACommonImpl.__init__``.

        Raises:
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        # A truthy value for any of these options requests a feature this
        # implementation does not provide.
        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferMLAImpl")

        # Every instance points at the same module-level workspace tensor.
        self._workspace_buffer = g_fi_workspace

        # Kernel scale factors; left unset here and filled in on the first
        # decode call, where the layer's scale values are available.
        self.bmm1_scale: Optional[float] = None
        self.bmm2_scale: Optional[float] = None

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the decode attention step through the FlashInfer kernel.

        Args:
            q: Either a single query tensor, or a ``(q_nope, q_pe)`` pair
                that is concatenated along the last dimension.
            kv_c_and_k_pe_cache: KV cache tensor; must be non-empty.
            attn_metadata: Attention metadata; its ``decode`` field must be
                set and supplies the block table and sequence lengths.
            layer: Attention layer providing the ``_q_scale_float``,
                ``_k_scale_float`` and ``_v_scale_float`` values.

        Returns:
            A ``(output, None)`` pair. The second element would be the LSE,
            which is not returned yet (see the TODO at the end of the body).
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        if isinstance(q, tuple):
            nope_part, rope_part = q
            query = torch.cat([nope_part, rope_part], dim=-1)
        else:
            query = q

        # trtllm API requires extra dimension q_len_per_request for MTP
        query = query.unsqueeze(1)

        # The scales are computed once from the first layer seen by this
        # instance and then reused on later calls.
        if self.bmm1_scale is None:
            qk_scale = layer._q_scale_float * layer._k_scale_float
            self.bmm1_scale = qk_scale * self.scale
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        # The kernel is given the cache with an extra dimension at index 1.
        kv_cache = kv_c_and_k_pe_cache.unsqueeze(1)

        attn_out = trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=decode_meta.block_table,
            seq_lens=decode_meta.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
        )

        # TODO: Return LSE pending support from Flashinfer API:
        # https://github.com/flashinfer-ai/flashinfer/pull/1566
        return attn_out, None
|