init
This commit is contained in:
114
vllm/v1/attention/backends/mla/flashinfer_mla.py
Normal file
114
vllm/v1/attention/backends/mla/flashinfer_mla.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
||||
return FlashInferMLAImpl
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferMLAImpl")
|
||||
|
||||
self._workspace_buffer = g_fi_workspace
|
||||
self.bmm1_scale: Optional[float] = None
|
||||
self.bmm2_scale: Optional[float] = None
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if isinstance(q, tuple):
|
||||
q_nope, q_pe = q
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# trtllm API requires extra dimension q_len_per_request for MTP
|
||||
q = q.unsqueeze(1)
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
|
||||
self.scale)
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
)
|
||||
|
||||
# TODO: Return LSE pending support from Flashinfer API:
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/1566
|
||||
return o, None
|
||||
Reference in New Issue
Block a user