308 lines
11 KiB
Python
308 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
|
|
|
import torch
|
|
|
|
import vllm._custom_ops as ops
|
|
from vllm._ipex_ops import ipex_ops
|
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
|
AttentionMetadataBuilder,
|
|
AttentionType,
|
|
is_quantized_kv_cache)
|
|
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
|
|
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
|
|
from vllm.utils import make_tensor_with_pad
|
|
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
|
|
|
|
|
|
class CPUMLABackend(AttentionBackend):
    """Backend descriptor for the CPU MLA attention path.

    Exposes the metadata, metadata-builder, state and implementation
    classes of this backend, plus the KV-cache shape and the block
    swap/copy helpers that forward to the custom ops module.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        """Return the attention metadata class used by this backend."""
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        """Return the class that builds this backend's metadata."""
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        """Return the attention state class used by this backend."""
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` is accepted for interface compatibility but does
        not contribute a dimension to the returned shape.
        """
        cache_shape = (num_blocks, block_size, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap request to ``ops.swap_blocks``."""
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward a block copy request to ``ops.copy_blocks_mla``."""
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of head sizes this backend accepts."""
        supported = [576]
        return supported
|
|
|
|
|
|
@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    """``TorchSDPAMetadata`` extended with the extra fields MLA needs."""

    # New for MLA
    # Input positions for rotary embeddings, since for MLA the rotary
    # position embeddings are applied inside the attention backend.
    # Defaults to None, so the annotation is explicitly Optional.
    input_positions: Optional[torch.Tensor] = None

    # required by MLACommonImpl
    is_profile_run: bool = False
|
|
|
|
|
|
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
    """Assembles ``CPUMLAMetadata`` from the CPU model-input builder."""

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        """Remember the input builder and reject chunked prefill."""
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        """Snapshot the input builder's current ``input_data``."""
        self.input_data = self.input_builder.input_data

    @staticmethod
    def _cumulative_start_locs(lengths, num_prefills):
        """Return an int32 CPU tensor ``[0, cumsum(lengths)...]``.

        The result has ``num_prefills + 1`` entries; entry 0 stays zero and
        the running sum of ``lengths`` is written into the remainder.
        """
        lengths_tensor = torch.tensor(lengths,
                                      dtype=torch.int32,
                                      device="cpu")
        start_locs = torch.zeros(num_prefills + 1,
                                 dtype=torch.int32,
                                 device="cpu")
        torch.cumsum(lengths_tensor,
                     dim=0,
                     dtype=torch.int32,
                     out=start_locs[1:])
        return start_locs

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        """Build the attention metadata for the current batch.

        ``seq_lens`` and ``query_lens`` are sliced to the first
        ``num_prefills`` entries to describe the prefill requests.
        ``cuda_graph_pad_size`` and ``batch_size`` are not used here.
        """
        data = self.input_data
        num_prefills = data.num_prefills

        prefill_seq_lens = seq_lens[:num_prefills]
        prefill_query_lens = query_lens[:num_prefills]
        slot_mapping = torch.tensor(data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # Prefill metadata: everything stays None when there is no prefill.
        query_start_loc = None
        kv_start_loc = None
        max_query_len = None
        max_kv_len = None
        prefill_block_tables = None
        if num_prefills > 0:
            query_start_loc = self._cumulative_start_locs(
                prefill_query_lens, num_prefills)
            kv_start_loc = self._cumulative_start_locs(
                prefill_seq_lens, num_prefills)
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # Block tables for prefill are only needed with chunked prefill.
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )

        # Decode metadata: with decode tokens present, the sequence lengths
        # after the prefill entries are used; otherwise the prefill entries
        # themselves are used and the block table is an empty tensor.
        has_decode = data.num_decode_tokens != 0
        if has_decode:
            selected_seq_lens = data.seq_lens[num_prefills:]
        else:
            selected_seq_lens = data.seq_lens[:num_prefills]
        seq_lens_tensor = torch.tensor(
            selected_seq_lens,
            dtype=torch.int32,
            device="cpu",
        )
        block_tables = make_tensor_with_pad(
            data.decode_block_tables,
            pad=0,
            dtype=torch.int32,
            device="cpu",
        ) if has_decode else torch.tensor([])

        # For multi-modal models
        placeholder_index_maps = None
        if len(data.multi_modal_inputs_list) > 0:
            placeholder_index_maps = {}
            for modality, placeholder_map in \
                    data.multi_modal_placeholder_maps.items():
                placeholder_index_maps[modality] = \
                    placeholder_map.index_map()

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=data.max_decode_seq_len,
            num_prefills=num_prefills,
            num_prefill_tokens=data.num_prefill_tokens,
            num_decode_tokens=data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=torch.tensor([data.input_positions]))
|
|
|
|
|
|
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
    """CPU implementation of MLA attention.

    Prefill runs through ``ipex_ops.varlen_attention``; decode runs through
    ``ops.mla_decode_kvcache_cpu`` followed by ``self._v_up_proj``.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state, then reject unsupported options.

        Raises:
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window``, ``blocksparse_params`` or
                ``logits_soft_cap`` is truthy, if ``attn_type`` is not
                ``AttentionType.DECODER``, or if the KV cache dtype is
                quantized.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # Any truthy value for these options means the caller asked for a
        # feature this implementation does not provide.
        if (alibi_slopes or sliding_window or blocksparse_params
                or logits_soft_cap):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        # Quantized (FP8) KV caches are rejected up front.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run causal variable-length attention over the prefill tokens.

        Keys and values are produced by applying ``self.kv_b_proj`` to
        ``kv_c_normed``; ``kv_c_and_k_pe_cache`` is not read here. The
        result is reshaped to ``(-1, num_heads * v_head_dim)``.
        """
        prefill_meta = attn_metadata.prefill_metadata
        assert prefill_meta is not None

        # Up-project the latent KV and split it into the no-PE key part and
        # the value part, per head.
        projected = self.kv_b_proj(kv_c_normed)[0]
        kv_per_head = projected.view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_per_head.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Expand the positional key part to k_nope's leading dims and
        # append it along the last dim.
        k_pe_expanded = k_pe.expand((*k_nope.shape[:-1], -1))
        k = torch.cat((k_nope, k_pe_expanded), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        qk_dim = q.shape[-1]
        v_dim = v.shape[-1]
        v_padded = torch.nn.functional.pad(v, [0, qk_dim - v_dim], value=0)

        output = torch.empty_like(q)
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=prefill_meta.prefill_query_start_loc,
            seqlen_k=prefill_meta.prefill_query_start_loc,
            max_seqlen_q=prefill_meta.max_query_len,
            max_seqlen_k=prefill_meta.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # remove padding
        unpadded = output.view(-1, self.num_heads, qk_dim)[..., :v_dim]
        return unpadded.reshape(-1, self.num_heads * v_dim)

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run decode attention against the KV cache.

        Concatenates ``q_nope`` and ``q_pe`` along the last dim, calls
        ``ops.mla_decode_kvcache_cpu`` to fill the output buffer, and
        returns ``self._v_up_proj`` applied to that buffer.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_metadata = attn_metadata.decode_metadata
        assert decode_metadata is not None

        query = torch.cat([q_nope, q_pe], dim=-1)
        attn_out = query.new_empty(query.shape[0], self.num_heads,
                                   self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(attn_out, query, kv_c_and_k_pe_cache,
                                   self.scale, decode_metadata.block_tables,
                                   decode_metadata.seq_lens_tensor)
        return self._v_up_proj(attn_out)
|