################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from dataclasses import dataclass
from typing import Optional

import torch

from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.mla import MLAModules
from vllm.model_executor.layers.quantization import QuantizationConfig


@dataclass
class SupaMLAModules(MLAModules):
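    """MLAModules extended with a standalone q_a_proj module.

    Inferred from the usage below: the fused SUPA path hands q_a_proj
    to the attention backend directly, while the native path goes
    through fused_qkv_a_proj instead.
    """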
    q_a_proj: Optional[torch.nn.Module]


@CustomOp.register("supa_multi_head_latent_attention")
class SupaMultiHeadLatentAttention(CustomOp):
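    """Multi-head latent attention (MLA) for BIREN SUPA devices.

    Wraps vLLM's Attention with the MLA arguments and, for the dense
    (non-sparse) case, the extra projection and rotary modules consumed
    by the BIREN fused-MLA kernel. forward_oot() dispatches between the
    native per-op path and the fused path.
    """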

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: SupaMLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.is_sparse = mla_modules.is_sparse
        self.q_a_proj = mla_modules.q_a_proj
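
        # Top-k metadata is only present when a sparse-attention indexer
        # module is attached.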
        if self.indexer is not None:
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #   k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #   kv_lora_rank + qk_rope_head_dim == head_size
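        # For example, with DeepSeek-V3-style dimensions
        # (kv_lora_rank=512, qk_rope_head_dim=64), head_size would be 576.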
        # Backend arguments shared by the sparse and dense paths.
        common_args = dict(
            num_heads=self.num_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scale,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sparse=mla_modules.is_sparse,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
            indexer=self.indexer,
        )
        if self.is_sparse:
            self.mla_attn = Attention(**common_args)
        else:
            self.mla_attn = Attention(
                **common_args,
                # BIREN args for fused MLA
                rotary_emb=self.rotary_emb,
                q_proj=self.q_proj
                if self.q_lora_rank is None else self.q_b_proj,
                o_proj=self.o_proj,
                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                kv_a_layernorm=self.kv_a_layernorm,
                q_a_proj=None if self.q_lora_rank is None else self.q_a_proj,
                q_a_layernorm=None
                if self.q_lora_rank is None else self.q_a_layernorm,
            )

        self.prefix = prefix
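        # Assumes a prefix of the form "...layers.<idx>.<name>", so the
        # second-to-last dotted component is the layer index.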
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
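        """Unfused MLA forward: down-project to the latent space, apply
        RoPE, run the MLA attention backend, then project the output.

        From the seq_len/unsqueeze handling below, hidden_states appears
        to be expected as (1, seq_len, hidden_size).
        """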
        q_c = None
        kv_lora = None

        if self.q_lora_rank is not None:
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0].view(-1,
                                           self.num_heads * self.qk_head_dim)
        else:
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_lora = kv_lora.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
                                   dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)
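
        # The returned indices are unused here; the call is presumably
        # made for its side effect of populating topk_indices_buffer
        # (captured in __init__).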
        if self.indexer and self.is_sparse:
            _topk_indices = self.indexer(hidden_states, q_c, positions,
                                         self.rotary_emb)

        seq_len = hidden_states.shape[1]
        attn_out = self.mla_attn(q,
                                 kv_c_normed,
                                 k_pe,
                                 output_shape=(seq_len, self.num_heads *
                                               self.v_head_dim))
        return self.o_proj(attn_out)[0].unsqueeze(0)

    def forward_supa(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
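        """Fused path: hand raw hidden_states straight to the backend,
        which (judging by the modules passed to Attention in __init__)
        performs the projections, RoPE, attention, and output projection
        internally.
        """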
        return self.mla_attn(hidden_states,
                             positions,
                             hidden_states,
                             output_shape=hidden_states.shape)

    def forward_oot(self, *args, is_ds_v32: Optional[int], **kwargs):
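        """Out-of-tree dispatch: requests flagged with is_ds_v32
        (presumably DeepSeek-V3.2) take the native path; all others use
        the fused SUPA path.
        """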
        if is_ds_v32:
            return self.forward_native(*args, **kwargs)
        else:
            return self.forward_supa(*args, **kwargs)