################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from dataclasses import dataclass
from typing import Optional

import torch
from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.mla import MLAModules
from vllm.model_executor.layers.quantization import QuantizationConfig


@dataclass
class SupaMLAModules(MLAModules):
    """``MLAModules`` bundle extended with a separate ``q_a_proj`` module."""

    # Forwarded to the attention layer as ``q_a_proj`` on the non-sparse
    # path when ``q_lora_rank`` is set; may be None.
    q_a_proj: Optional[torch.nn.Module]


@CustomOp.register("supa_multi_head_latent_attention")
class SupaMultiHeadLatentAttention(CustomOp):
    """Multi-head latent attention (MLA) custom op with two forward paths.

    ``forward_native`` runs the projections, layernorms and rotary embedding
    in this module before calling the attention layer. ``forward_supa``
    passes the raw hidden states and positions straight to the attention
    layer, which on the non-sparse path was constructed with the projection
    modules. ``forward_oot`` selects between the two.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Store the MLA sub-modules and build the attention layer.

        Args:
            hidden_size: Stored as ``self.hidden_size``.
            num_heads: Number of attention heads.
            scale: Scale forwarded to the attention layer.
            qk_nope_head_dim: Per-head query/key size without rotary part.
            qk_rope_head_dim: Per-head query/key size of the rotary part.
            v_head_dim: Per-head value size.
            q_lora_rank: Query low-rank size, or None to use ``q_proj``
                directly instead of the low-rank query path.
            kv_lora_rank: Key/value low-rank size.
            mla_modules: Module bundle. ``q_a_proj`` is read from it, so it
                must provide that attribute (as ``SupaMLAModules`` does).
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to the attention layer.
            prefix: Dotted module name; the attention layer is registered
                under ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads

        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.is_sparse = mla_modules.is_sparse
        self.q_a_proj = mla_modules.q_a_proj

        if self.indexer is not None:
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        attn_kwargs = dict(
            num_heads=self.num_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scale,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sparse=mla_modules.is_sparse,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
            indexer=self.indexer,
        )
        if not self.is_sparse:
            # BIREN args for fused MLA: only the non-sparse attention layer
            # receives the projection / norm / rotary modules.
            has_q_lora = self.q_lora_rank is not None
            attn_kwargs.update(
                rotary_emb=self.rotary_emb,
                q_proj=self.q_b_proj if has_q_lora else self.q_proj,
                o_proj=self.o_proj,
                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                kv_a_layernorm=self.kv_a_layernorm,
                q_a_proj=self.q_a_proj if has_q_lora else None,
                q_a_layernorm=self.q_a_layernorm if has_q_lora else None,
            )
        self.mla_attn = Attention(**attn_kwargs)

        self.prefix = prefix
        # Debug-only layer index; None when the prefix does not carry one
        # (e.g. the default empty prefix).
        self.debug_layer_idx = self._parse_layer_idx(prefix)

    @staticmethod
    def _parse_layer_idx(prefix: str) -> Optional[int]:
        """Return the integer second-to-last dotted component of *prefix*.

        Returns None when *prefix* has fewer than two dotted components or
        when that component is not an integer.
        """
        parts = prefix.split(".")
        if len(parts) < 2:
            return None
        try:
            return int(parts[-2])
        except ValueError:
            return None

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run MLA with the projections executed in this module.

        Args:
            positions: Positions passed to the rotary embedding (and to the
                indexer on the sparse path).
            hidden_states: Input activations. NOTE(review): the code reads
                the sequence length from ``shape[1]`` and unsqueezes the
                output at dim 0, so this presumably expects a leading
                dimension of size 1 — confirm against the caller.

        Returns:
            The output of ``o_proj`` with a new leading dimension.
        """
        q_c = None
        kv_lora = None

        if self.q_lora_rank is not None:
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            # One fused projection yields both the low-rank query and the
            # compressed kv (+ rotary key) features; split them apart.
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0].view(
                -1, self.num_heads * self.qk_head_dim)
        else:
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_lora = kv_lora.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
        kv_c, k_pe = kv_lora.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        # Rotary embedding is applied only to the rope slice of q and to
        # k_pe.
        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)

        if self.indexer is not None and self.is_sparse:
            # The return value is not used here; presumably the indexer
            # records its top-k selection elsewhere — TODO confirm.
            self.indexer(hidden_states, q_c, positions, self.rotary_emb)

        seq_len = hidden_states.shape[1]
        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(seq_len, self.num_heads * self.v_head_dim),
        )
        return self.o_proj(attn_out)[0].unsqueeze(0)

    def forward_supa(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run MLA by delegating everything to the attention layer.

        ``hidden_states`` is passed as both the first and third positional
        argument of the attention layer and ``positions`` as the second;
        the requested output shape equals ``hidden_states.shape``.
        """
        return self.mla_attn(
            hidden_states,
            positions,
            hidden_states,
            output_shape=hidden_states.shape,
        )

    def forward_oot(self, *args, is_ds_v32: Optional[int], **kwargs):
        """Dispatch to ``forward_native`` or ``forward_supa``.

        Args:
            *args: Positional arguments forwarded to the chosen path.
            is_ds_v32: Required keyword; a truthy value selects
                ``forward_native``, a falsy value selects ``forward_supa``.
            **kwargs: Keyword arguments forwarded to the chosen path.
        """
        if is_ds_v32:
            return self.forward_native(*args, **kwargs)
        return self.forward_supa(*args, **kwargs)