init
202 lines added
vllm_vacc/vllm/attention/backends/mla/common.py (new file)
@@ -0,0 +1,202 @@
import functools
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
                    Type, TypeVar)

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionState, MLAAttentionImpl)
from vllm.attention.backends.mla.common import (MLACommonMetadata,
                                                triton_attention)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase, RowParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.platforms import current_platform

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    is_vllm_fa = False
try:
    # For ROCm use upstream flash attention
    from vllm.attention.backends.flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None

T = TypeVar("T", bound="MLACommonMetadata")

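# In MLA, each token's KV-cache entry is the compressed latent (kv_lora_rank
# elements) concatenated with the rotary part (qk_rope_head_dim elements), and
# queries split their head dim as
# qk_head_dim = qk_nope_head_dim + qk_rope_head_dim. This is implied by how
# MLACommonImpl below slices q and writes the cache via
# ops.concat_and_cache_mla; the exact cache layout is defined by that op.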
class MLACommonImpl():

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        # blocksparse_params: Optional[Dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        rotary_emb: RotaryEmbedding,
        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
        # attention backend perspective we rely on the layer to pass in the
        # correct matrix
        q_proj: ColumnParallelLinear,
        kv_b_proj: ColumnParallelLinear,
        o_proj: RowParallelLinear,
        positions: Optional[torch.Tensor] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim

        self.rotary_emb = rotary_emb
        self.use_yarn_rope = isinstance(rotary_emb,
                                        DeepseekScalingRotaryEmbedding)
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.positions = positions

        self.triton_fa_func = triton_attention
        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on ROCm and the
        # latter has an additional parameter to control FA2 vs FA3.
        self.flash_attn_varlen_func = flash_attn_varlen_func
        self.vllm_flash_attn_version = get_flash_attn_version()
        if self.vllm_flash_attn_version is not None:
            self.flash_attn_varlen_func = \
                functools.partial(flash_attn_varlen_func,
                                  fa_version=self.vllm_flash_attn_version)

        # For MLA the v head dim is smaller than the qk head dim, so we pad
        # v with 0s to match the qk head dim for attention backends that do
        # not support different head dims.
        # We don't need to pad v if we are on a Hopper system with FA3.
        self._pad_v = self.vllm_flash_attn_version is None or not (
            self.vllm_flash_attn_version == 3
            and current_platform.get_device_capability()[0] == 9)

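    # When self._pad_v is True, a backend that requires q, k and v to share a
    # head dim is expected to zero-pad v from v_head_dim up to qk_head_dim
    # before attention and slice the extra columns off the output afterwards.
    # Rough illustration only, not part of this file's call path:
    #
    #   v_padded = torch.nn.functional.pad(
    #       v, [0, self.qk_head_dim - self.v_head_dim], value=0)
    #   ...run attention with head dim qk_head_dim...
    #   attn_out = attn_out[..., :self.v_head_dim]
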
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if output is not None:
            raise NotImplementedError(
                "output is not yet supported for MLAImplBase")

        # if attn_metadata.is_profile_run and \
        #         attn_metadata.context_chunk_workspace is not None:
        #     # During the profile run, try to simulate the worst-case output
        #     # size for `self.kv_b_proj(kv_c_normed)` in
        #     # `_compute_prefill_context`, since this can be large
        #     _ = torch.empty(
        #         (attn_metadata.context_chunk_workspace.shape[0],
        #          self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
        #         device=k_c_normed.device,
        #         dtype=k_c_normed.dtype,
        #     )

        has_decode = attn_metadata.decode_metadata is not None
        has_prefill = attn_metadata.prefill_metadata is not None

        # Restore head dim (for rotary embedding)
        k_pe = k_pe.unsqueeze(1)
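        # After the unsqueeze, k_pe carries an explicit head dim of 1
        # ([num_tokens, 1, qk_rope_head_dim]): MLA keeps a single rotary key
        # shared across all query heads, so only one "kv head" is passed
        # around here (shape inferred from the slicing below, not stated
        # elsewhere in this hunk).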
        # assert hasattr(attn_metadata, "input_positions")
        if self.positions is not None:
            positions = self.positions
        elif hasattr(attn_metadata, "input_positions"):
            positions = attn_metadata.input_positions
        else:
            raise ValueError("no positions available for rotary embedding")

        num_prefill_tokens: int = attn_metadata.num_prefill_tokens

        decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
        decode_k_pe = k_pe[num_prefill_tokens:]
        decode_input_positions = positions[num_prefill_tokens:]

        prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
        prefill_k_pe = k_pe[:num_prefill_tokens]
        prefill_input_positions = positions[:num_prefill_tokens]
        prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

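        # The slicing above assumes the flattened batch is laid out as
        # [prefill tokens ... | decode tokens ...]; e.g. with
        # num_prefill_tokens == 3 and 2 decode tokens,
        # hidden_states_or_q_c[:3] are the prefill rows and
        # hidden_states_or_q_c[3:] the decode rows.
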
        if has_decode:
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                decode_input_positions, decode_q_pe, decode_k_pe)

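        # rotary_emb returns the rotated (q_pe, k_pe) pair; assigning through
        # `[...]` writes the result back in place, so decode_k_pe stays a view
        # into k_pe and the rotated values are what later gets written to the
        # KV cache. The prefill branch below uses the same pattern.
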
        if has_prefill:
            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                .view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                prefill_input_positions, prefill_q_pe, prefill_k_pe)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

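        # The call above stores, per token, the compressed latent k_c_normed
        # concatenated with the rotary part k_pe into the paged KV cache;
        # attn_metadata.slot_mapping gives each token's flat slot index in
        # that cache (standard vLLM paged-attention bookkeeping).
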
        # output = torch.empty(attn_metadata.num_prefill_tokens +
        #                      attn_metadata.num_decode_tokens,
        #                      self.o_proj.output_size,
        #                      device=hidden_states_or_q_c.device,
        #                      dtype=hidden_states_or_q_c.dtype)
        if has_prefill:
            return self._forward_prefill(
                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                attn_metadata)

        if has_decode:
            return self._forward_decode(
                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

        assert False, "MLA forward needs prefill or decode metadata"
        return None
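    # NOTE: _q_proj_and_k_up_proj, _forward_prefill and _forward_decode are
    # referenced above but not defined in this hunk; backend-specific
    # subclasses (or later parts of this file) are expected to provide them,
    # roughly along these lines (sketch only, signatures inferred from the
    # call sites above):
    #
    #   def _q_proj_and_k_up_proj(self, x):
    #       ...  # -> (ql_nope, q_pe)
    #
    #   def _forward_prefill(self, q, kv_c_normed, k_pe, kv_cache,
    #                        attn_metadata):
    #       ...  # -> attention output for the prefill tokens
    #
    #   def _forward_decode(self, ql_nope, q_pe, kv_cache, attn_metadata):
    #       ...  # -> attention output for the decode tokens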