# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.lightning_attn import (
    lightning_attention,
    linear_decode_forward_triton,
)
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata


class MiniMaxText01RMSNormTP(CustomOp):
    name = "MiniMaxText01RMSNormTP"

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))
        self.weight.weight_loader = self.weight_loader
        self.variance_epsilon = eps
        return

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
    ) -> None:
        tp_world = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = loaded_weight.shape[0] // tp_world
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        param.data.copy_(loaded_weight[shard])
        return

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
        if self.tp_world > 1:
            variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = (x * self.weight).to(orig_dtype)
        return x

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)
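

# `MiniMaxText01LinearKernel` is a thin wrapper around the Triton
# `lightning_attention` kernel used on the prefill path. It processes one
# sequence at a time (hence the batch-size-1 assert below), feeds the slot's
# KV history through the kernel, and writes the updated recurrent state back
# into the cache in place.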
class MiniMaxText01LinearKernel:
    @staticmethod
    def jit_linear_forward_prefix(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kv_caches: torch.Tensor,
        slope_rate: torch.Tensor,
        block_size: int,
        layer_idx: int | None = None,
        **kwargs,
    ) -> torch.Tensor:
        slope_rate = slope_rate.to(torch.float32)
        should_pad_dim = q.dim() == 3
        if should_pad_dim:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        b, h, n, d = q.shape
        e = d
        kv_history = kv_caches.reshape(1, h, d, e).contiguous()
        output, kv_history = lightning_attention(
            q, k, v, slope_rate, block_size=block_size, kv_history=kv_history
        )
        kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
        assert output.shape[0] == 1, "batch size must be 1"
        return rearrange(output.squeeze(0), "h n d -> n (h d)")


class MiniMaxText01LinearAttention(nn.Module, MambaBase):
    @property
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_state_dtype(self) -> tuple[torch.dtype]:
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim
        )

    def __init__(
        self,
        hidden_size: int,
        hidden_inner_size: int,
        num_heads: int,
        head_dim: int,
        max_position: int,
        block_size: int,
        num_hidden_layer: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        layer_idx: int = 0,
        linear_layer_idx: int = 0,
        prefix: str = "linear_attn",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.BLOCK = block_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.total_num_heads = num_heads
        self.hidden_inner_size = hidden_inner_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        assert self.total_num_heads % self.tp_size == 0
        self.tp_heads = self.total_num_heads // self.tp_size
        self.qkv_size = self.num_heads * self.head_dim
        self.tp_hidden = self.head_dim * self.tp_heads
        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size * 3,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.output_gate = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_gate",
        )
        self.out_proj = RowParallelLinear(
            self.hidden_inner_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.norm = MiniMaxText01RMSNormTP(
            self.hidden_inner_size,
            eps=1e-5,
        )

        slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads)
        if num_hidden_layer <= 1:
            self.slope_rate = slope_rate * (1 + 1e-5)
        else:
            self.slope_rate = slope_rate * (
                1 - layer_idx / (num_hidden_layer - 1) + 1e-5
            )
        self.tp_slope = self.slope_rate[
            self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads
        ].contiguous()

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    @staticmethod
    def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)
        return
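
    # The decay rates below follow the ALiBi slope construction: for a
    # power-of-two head count n, head i gets slope 2 ** (-8 * (i + 1) / n);
    # other head counts interleave slopes from the two nearest powers of two.
    # __init__ above then scales the slopes per layer by roughly
    # (1 - layer_idx / (num_hidden_layer - 1)), so decay weakens with depth.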
    @staticmethod
    def _build_slope_tensor(n_attention_heads: int):
        def get_slopes(n):
            def get_slopes_power_of_2(n):
                start = 2 ** (-(2 ** -(math.log2(n) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)
            else:
                closest_power_of_2 = 2 ** math.floor(math.log2(n))
                return (
                    get_slopes_power_of_2(closest_power_of_2)
                    + get_slopes(2 * closest_power_of_2)[0::2][
                        : n - closest_power_of_2
                    ]
                )

        slopes = torch.tensor(
            get_slopes(n_attention_heads), dtype=torch.float32
        ).reshape(n_attention_heads, 1, 1)
        return slopes

    def _prefill_and_mix_infer(
        self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
    ):
        hidden = []
        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_idx >= len(attn_metadata.query_start_loc):
                break
            if _prefill_idx >= len(state_indices_tensor):
                break
            offset = attn_metadata.num_decode_tokens
            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
            slot_id = state_indices_tensor[offset + _prefill_idx]
            qs = q[_start:_end].transpose(0, 1).contiguous()
            ks = k[_start:_end].transpose(0, 1).contiguous()
            vs = v[_start:_end].transpose(0, 1).contiguous()
            slice_layer_cache = kv_cache[slot_id, ...]

            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
                qs,
                ks,
                vs,
                slice_layer_cache,
                self.tp_slope,
                self.BLOCK,
                layer_idx=self.layer_idx,
            )
            hidden.append(out_slice.contiguous())

        if attn_metadata.num_decode_tokens > 0:
            hidden_decode = self._decode_infer(
                q, k, v, kv_cache, state_indices_tensor, attn_metadata
            )
            hidden.insert(0, hidden_decode)

        if not hidden:
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

        hidden = torch.concat(hidden, dim=0).contiguous()
        return hidden

    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
        q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        slot_id = state_indices_tensor[: attn_metadata.num_decodes]
        hidden = linear_decode_forward_triton(
            q, k, v, kv_cache, self.tp_slope, slot_id, 32
        )
        return hidden
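
    # `forward` dispatches through the `torch.ops.vllm.linear_attention`
    # custom op registered at the bottom of this file, so the in-place cache
    # update stays opaque to torch.compile; the real work happens in
    # `_forward` below.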
    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        positions: torch.Tensor,
    ) -> None:
        torch.ops.vllm.linear_attention(
            hidden_states,
            output,
            positions,
            self.prefix,
        )

    def _forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        positions: torch.Tensor,
    ) -> None:
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, LinearAttentionMetadata)
            num_actual_tokens = (
                attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
            )
        else:
            num_actual_tokens = hidden_states.shape[0]

        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
        qkv32 = qkv.to(torch.float32)
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)

        if attn_metadata is not None:
            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
            state_indices_tensor = attn_metadata.state_indices_tensor

            num_prefills = getattr(attn_metadata, "num_prefills", 0)
            if num_prefills > 0:
                num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
                for prefill_idx in range(num_prefills):
                    q_start = attn_metadata.query_start_loc[
                        num_decode_tokens + prefill_idx
                    ]
                    q_end = attn_metadata.query_start_loc[
                        num_decode_tokens + prefill_idx + 1
                    ]
                    query_len = q_end - q_start
                    context_len = (
                        attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
                        - query_len
                    )
                    if context_len == 0:
                        block_to_clear = state_indices_tensor[
                            num_decode_tokens + prefill_idx
                        ]
                        kv_cache[block_to_clear, ...] = 0

        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if attn_metadata is None:
            hidden = torch.empty(
                (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype
            )
        else:
            if not decode_only:
                hidden = self._prefill_and_mix_infer(
                    q, k, v, kv_cache, state_indices_tensor, attn_metadata
                )
            else:
                hidden = self._decode_infer(
                    q, k, v, kv_cache, state_indices_tensor, attn_metadata
                )

        hidden = self.norm._forward(hidden)
        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)
        output[:num_actual_tokens], _ = self.out_proj(hidden)


def linear_attention(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self._forward(hidden_states=hidden_states, output=output, positions=positions)


def linear_attention_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    return


direct_register_custom_op(
    op_name="linear_attention",
    op_func=linear_attention,
    mutates_args=["output"],
    fake_impl=linear_attention_fake,
)
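
# `direct_register_custom_op` exposes `linear_attention` as
# `torch.ops.vllm.linear_attention`, the op invoked by
# `MiniMaxText01LinearAttention.forward`. `mutates_args=["output"]` declares
# the in-place write to `output`, and `linear_attention_fake` serves as the
# no-op fake implementation used for shape propagation during tracing.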