Files

202 lines
7.9 KiB
Python
Raw Permalink Normal View History

2026-04-02 04:53:13 +00:00
import functools
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar)
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, MLAAttentionImpl)
from vllm.attention.backends.mla.common import MLACommonMetadata,triton_attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.platforms import current_platform
# Resolve `flash_attn_varlen_func` from whichever flash-attention build is
# importable. `is_vllm_fa` records whether the vLLM-bundled kernels were
# found; if neither import works the function is left as None.
# NOTE(review): the second `try` is assumed to be the fallback inside the
# first `except` (per the "For rocm" comment) -- confirm against upstream.
try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    is_vllm_fa = False
    try:
        # For rocm use upstream flash attention
        from vllm.attention.backends.flash_attn import flash_attn_varlen_func
    except ImportError:
        # No flash-attention implementation could be imported.
        flash_attn_varlen_func = None
# Type variable for the attention-metadata argument of `forward`; bound (by
# forward-reference name) to MLACommonMetadata.
T = TypeVar("T", bound="MLACommonMetadata")
class MLACommonImpl():
    """Common forward-pass driver for an MLA attention implementation.

    The class stores the MLA dimensions and projection layers, applies the
    rotary embedding, writes the latent/rope key parts to the KV cache and
    then dispatches to the prefill and/or decode paths.

    ``forward`` calls ``_q_proj_and_k_up_proj``, ``_forward_prefill`` and
    ``_forward_decode``; none of them is defined in this part of the class,
    so they are presumably supplied by a subclass or mixin.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        rotary_emb: RotaryEmbedding,
        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
        # attention backend perspective we rely on the layer to pass in the
        # correct matrix
        q_proj: ColumnParallelLinear,
        kv_b_proj: ColumnParallelLinear,
        o_proj: RowParallelLinear,
        positions: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the attention configuration and pick the attention kernels.

        ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap``,
        ``attn_type`` and ``kv_sharing_target_layer_name`` are accepted for
        signature compatibility but are not stored or validated here.

        Args:
            positions: Optional position tensor. When given, ``forward``
                uses it instead of ``attn_metadata.input_positions``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim

        self.rotary_emb = rotary_emb
        self.use_yarn_rope = isinstance(rotary_emb,
                                        DeepseekScalingRotaryEmbedding)
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.positions = positions
        self.triton_fa_func = triton_attention

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
        self.flash_attn_varlen_func = flash_attn_varlen_func
        self.vllm_flash_attn_version = get_flash_attn_version()
        # Only bind `fa_version` when a kernel was actually imported; the
        # module-level fallback may have left `flash_attn_varlen_func` as
        # None, and functools.partial rejects a non-callable first argument.
        if (self.vllm_flash_attn_version is not None
                and flash_attn_varlen_func is not None):
            self.flash_attn_varlen_func = functools.partial(
                flash_attn_varlen_func,
                fa_version=self.vllm_flash_attn_version)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim for attention backends that do
        # not support different headdims
        # We don't need to pad V if we are on a hopper system with FA3
        self._pad_v = self.vllm_flash_attn_version is None or not (
            self.vllm_flash_attn_version == 3
            and current_platform.get_device_capability()[0] == 9)

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply RoPE, update the KV cache and run prefill and/or decode.

        Tokens are laid out with the ``attn_metadata.num_prefill_tokens``
        prefill tokens first and the decode tokens after them; every
        per-token input is sliced accordingly.

        Args:
            layer: Attention layer; only ``layer._k_scale`` is read here.
            hidden_states_or_q_c: Per-token query-side input.
            k_c_normed: Per-token latent key part written to the cache.
            k_pe: Per-token rope key part; rotated in place before caching.
            kv_cache: Cache tensor; the cache write is skipped when empty.
            attn_metadata: Must provide ``prefill_metadata``,
                ``decode_metadata``, ``num_prefill_tokens``, ``slot_mapping``
                and, unless ``positions`` was given to the constructor,
                ``input_positions``.
            output: Must be ``None``; preallocated outputs are unsupported.

        Returns:
            The prefill result, the decode result, or -- when the batch
            contains both -- their concatenation along dim 0 in token order
            (prefill first).

        Raises:
            NotImplementedError: If ``output`` is not ``None``.
            ValueError: If no position tensor is available.
            AssertionError: If the metadata has neither prefill nor decode.
        """
        if output is not None:
            raise NotImplementedError(
                "output is not yet supported for MLAImplBase")

        has_decode = attn_metadata.decode_metadata is not None
        has_prefill = attn_metadata.prefill_metadata is not None

        # Restore head dim (for rotary embedding)
        k_pe = k_pe.unsqueeze(1)

        if self.positions is not None:
            positions = self.positions
        elif hasattr(attn_metadata, "input_positions"):
            positions = attn_metadata.input_positions
        else:
            raise ValueError('no positions')

        num_prefill_tokens: int = attn_metadata.num_prefill_tokens

        # The k_pe / q_pe slices below are views, so assigning the rotary
        # outputs through `[...]` updates `k_pe` (and `prefill_q`) in place;
        # the rotated keys are therefore what gets written to the cache.
        if has_decode:
            decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
            decode_k_pe = k_pe[num_prefill_tokens:]
            decode_input_positions = positions[num_prefill_tokens:]
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                decode_input_positions, decode_q_pe, decode_k_pe)

        if has_prefill:
            prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
            prefill_k_pe = k_pe[:num_prefill_tokens]
            prefill_input_positions = positions[:num_prefill_tokens]
            prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                .view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                prefill_input_positions, prefill_q_pe, prefill_k_pe)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        # Run every phase present in the batch. Returning straight from the
        # prefill branch would silently drop the decode tokens of a mixed
        # batch.
        phase_outputs = []
        if has_prefill:
            phase_outputs.append(
                self._forward_prefill(prefill_q, prefill_k_c_normed,
                                      prefill_k_pe, kv_cache, attn_metadata))
        if has_decode:
            phase_outputs.append(
                self._forward_decode(decode_ql_nope, decode_q_pe, kv_cache,
                                     attn_metadata))

        if not phase_outputs:
            # Raised explicitly (not via `assert`) so that the check also
            # fires when Python runs with -O.
            raise AssertionError(
                "mla forward need prefill or decode function")
        if len(phase_outputs) == 1:
            return phase_outputs[0]
        # Mixed batch: stitch results back together in token order.
        # NOTE(review): assumes both helpers return tensors whose trailing
        # dimensions match -- confirm against their implementations.
        return torch.cat(phase_outputs, dim=0)