202 lines
7.9 KiB
Python
202 lines
7.9 KiB
Python
|
|
import functools
|
|
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
|
|
Type, TypeVar)
|
|
import torch
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.attention.backends.abstract import AttentionLayer
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
|
AttentionMetadata,
|
|
AttentionMetadataBuilder,
|
|
AttentionState, MLAAttentionImpl)
|
|
from vllm.attention.backends.mla.common import MLACommonMetadata,triton_attention
|
|
|
|
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
LinearBase, RowParallelLinear,
|
|
UnquantizedLinearMethod)
|
|
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
|
|
|
from vllm.model_executor.layers.rotary_embedding import (
|
|
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
# Resolve the varlen flash-attention entry point. vLLM's bundled build is
# preferred; ``is_vllm_fa`` records whether that build was found.
try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    is_vllm_fa = False
    # For ROCm, fall back to the upstream flash-attention entry point; when
    # that import fails too, the function is left as None.
    try:
        from vllm.attention.backends.flash_attn import flash_attn_varlen_func
    except ImportError:
        flash_attn_varlen_func = None
else:
    is_vllm_fa = True

# Type variable for the attention-metadata argument of ``forward``; the bound
# is a string forward reference to MLACommonMetadata.
T = TypeVar("T", bound="MLACommonMetadata")
|
|
|
|
|
|
class MLACommonImpl:
    """Shared MLA (multi-head latent attention) logic for prefill and decode.

    Holds the MLA projection layers and head-dimension configuration, applies
    the rotary embedding to the positional ("pe") parts of queries and keys,
    writes the compressed latent (``k_c_normed``) and rotary key (``k_pe``)
    into the KV cache, and dispatches to the prefill / decode attention paths.

    ``_q_proj_and_k_up_proj``, ``_forward_prefill`` and ``_forward_decode`` are
    called but not defined here; they must be supplied elsewhere (presumably
    by a subclass — NOTE(review): confirm against the concrete backends).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        rotary_emb: RotaryEmbedding,
        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
        # attention backend perspective we rely on the layer to pass in the
        # correct matrix
        q_proj: ColumnParallelLinear,
        kv_b_proj: ColumnParallelLinear,
        o_proj: RowParallelLinear,
        positions: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the MLA configuration and select the flash-attention callable.

        Args:
            num_heads: Number of query heads; used to reshape the prefill query.
            head_size: Stored as-is.
            scale: Attention scale; stored as a Python float.
            num_kv_heads: Stored as-is.
            kv_cache_dtype: Passed to ``ops.concat_and_cache_mla`` when
                writing the cache.
            q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim,
            qk_head_dim, v_head_dim: MLA dimensions. In ``forward`` the
                prefill query is viewed as ``(-1, num_heads, qk_head_dim)``
                and its rotary part starts at ``qk_nope_head_dim``.
            rotary_emb: Rotary embedding applied to the query/key "pe" parts.
                ``use_yarn_rope`` is True when it is a
                ``DeepseekScalingRotaryEmbedding``.
            q_proj: Query projection (applied to prefill inputs in ``forward``).
            kv_b_proj, o_proj: Stored for use by helpers not defined here.
            positions: Optional fixed positions tensor. When given, ``forward``
                uses it instead of ``attn_metadata.input_positions``.

        ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap``, ``attn_type``
        and ``kv_sharing_target_layer_name`` are accepted for interface
        compatibility but are not read by this class.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim

        self.rotary_emb = rotary_emb
        self.use_yarn_rope = isinstance(rotary_emb,
                                        DeepseekScalingRotaryEmbedding)
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.positions = positions

        self.triton_fa_func = triton_attention
        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
        self.flash_attn_varlen_func = flash_attn_varlen_func
        self.vllm_flash_attn_version = get_flash_attn_version()
        if self.vllm_flash_attn_version is not None:
            self.flash_attn_varlen_func = \
                functools.partial(flash_attn_varlen_func,
                                  fa_version=self.vllm_flash_attn_version)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim for attention backends that do
        # not support different headdims
        # We don't need to pad V if we are on a hopper system with FA3
        # (the device-capability query only runs when the FA version is 3).
        self._pad_v = self.vllm_flash_attn_version is None or not (
            self.vllm_flash_attn_version == 3
            and current_platform.get_device_capability()[0] == 9)

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention over a batch of prefill and/or decode tokens.

        Tokens are laid out prefill-first: the first
        ``attn_metadata.num_prefill_tokens`` rows belong to prefill requests
        and the remaining rows to decode requests.

        Side effects:
            * The rotary embedding is written in place into ``k_pe`` (through
              tensor views), so the caller's ``k_pe`` is modified.
            * When ``kv_cache`` is non-empty, ``k_c_normed`` and the rotated
              ``k_pe`` are stored into it at ``attn_metadata.slot_mapping``.

        Returns:
            The prefill output, the decode output, or, for a mixed batch,
            both concatenated along dim 0 in prefill-first order.

        Raises:
            NotImplementedError: If ``output`` is provided.
            ValueError: If no positions are available, or if the metadata
                describes neither prefill nor decode work.
        """
        if output is not None:
            raise NotImplementedError(
                "output is not yet supported for MLAImplBase")

        has_decode = attn_metadata.decode_metadata is not None
        has_prefill = attn_metadata.prefill_metadata is not None

        # Restore head dim (for rotary embedding)
        k_pe = k_pe.unsqueeze(1)

        # Positions given at construction time take precedence over the
        # positions carried by the metadata.
        if self.positions is not None:
            positions = self.positions
        elif hasattr(attn_metadata, "input_positions"):
            positions = attn_metadata.input_positions
        else:
            raise ValueError('no positions')

        num_prefill_tokens: int = attn_metadata.num_prefill_tokens

        # The slices below are views: the ``[...] =`` assignments write the
        # rotary outputs back into ``k_pe`` before it is cached.
        if has_decode:
            decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
            decode_k_pe = k_pe[num_prefill_tokens:]
            decode_input_positions = positions[num_prefill_tokens:]

            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                decode_input_positions, decode_q_pe, decode_k_pe)

        if has_prefill:
            prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
            prefill_k_pe = k_pe[:num_prefill_tokens]
            prefill_input_positions = positions[:num_prefill_tokens]
            prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                .view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                prefill_input_positions, prefill_q_pe, prefill_k_pe)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if has_prefill:
            prefill_output = self._forward_prefill(
                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                attn_metadata)

        if has_decode:
            decode_output = self._forward_decode(
                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

        if has_prefill and has_decode:
            # Mixed batch: every token needs an output row, in the same
            # prefill-first order as the inputs. Assumes both paths return
            # per-token outputs with matching trailing dims
            # (NOTE(review): confirm for the concrete backends).
            return torch.cat((prefill_output, decode_output), dim=0)
        if has_prefill:
            return prefill_output
        if has_decode:
            return decode_output

        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        raise ValueError("mla forward need prefill or decode function")