# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from einops import rearrange
from torch import nn

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

from .fla.ops.kda import (
    FusedRMSNormGated,
    chunk_kda,
    fused_kda_gate,
    fused_recurrent_kda,
)
from .linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from .mamba.abstract import MambaBase
from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig

logger = init_logger(__name__)


def kda_attention(
    q_proj_states: torch.Tensor,
    k_proj_states: torch.Tensor,
    v_proj_states: torch.Tensor,
    g1: torch.Tensor,
    beta: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self._forward(
        q_proj_states=q_proj_states,
        k_proj_states=k_proj_states,
        v_proj_states=v_proj_states,
        g1=g1,
        beta=beta,
        core_attn_out=core_attn_out,
    )


def kda_attention_fake(
    q_proj_states: torch.Tensor,
    k_proj_states: torch.Tensor,
    v_proj_states: torch.Tensor,
    g1: torch.Tensor,
    beta: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    return


direct_register_custom_op(
    op_name="kda_attention",
    op_func=kda_attention,
    mutates_args=["core_attn_out"],
    fake_impl=kda_attention_fake,
)


class KimiDeltaAttention(nn.Module, MambaBase):
    @property
    def mamba_type(self) -> str:
        return "gdn_attention"

    def get_state_dtype(
        self,
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
        if self.model_config is None or self.cache_config is None:
            raise ValueError("model_config and cache_config must be set")
        return MambaStateDtypeCalculator.kda_state_dtype(
            self.model_config.dtype, self.cache_config.mamba_cache_dtype
        )

    def get_state_shape(
        self,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.kda_state_shape(
            self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
        )

    def __init__(
        self,
        layer_idx: int,
        hidden_size: int,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        model_config: ModelConfig | None = None,
        rms_norm_eps: float = 1e-5,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = hidden_size
        self.model_config = model_config
        self.cache_config = cache_config

        if model_config is None:
            raise ValueError("model_config must be provided")
        kda_config = model_config.linear_attn_config
        self.head_dim = kda_config["head_dim"]
        self.num_heads = kda_config["num_heads"]
        self.layer_idx = layer_idx
        self.prefix = prefix

        assert self.num_heads % self.tp_size == 0
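        # Heads are sharded evenly across tensor-parallel ranks; each rank
        # owns num_heads // tp_size heads of size head_dim.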
        self.local_num_heads = divide(self.num_heads, self.tp_size)
        projection_size = self.head_dim * self.num_heads
        self.conv_size = kda_config["short_conv_kernel_size"]

        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj",
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj",
        )

        self.f_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_a_proj",
        )
        self.f_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_b_proj",
        )
        self.dt_bias = nn.Parameter(
            torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
        )
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.b_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.b_proj",
        )

        self.q_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.q_conv1d",
        )
        self.k_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.k_conv1d",
        )
        self.v_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.v_conv1d",
        )
        # Unsqueeze to fit the conv1d weight shape into the linear weight
        # shape. This can't be done in `weight_loader` since one already
        # exists in `ColumnParallelLinear` and `set_weight_attrs` doesn't
        # allow overriding it.
        self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
        self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
        self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)

        self.A_log = nn.Parameter(
            torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
        )
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})

        self.g_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_a_proj",
        )
        self.g_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_b_proj",
        )
        self.o_norm = FusedRMSNormGated(
            self.head_dim, eps=rms_norm_eps, activation="sigmoid"
        )
        self.o_proj = RowParallelLinear(
            projection_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        num_tokens = hidden_states.size(0)
        q = self.q_proj(hidden_states)[0]
        k = self.k_proj(hidden_states)[0]
        v = self.v_proj(hidden_states)[0]

        beta = self.b_proj(hidden_states)[0].float().sigmoid()

        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)

        beta = beta.unsqueeze(0)
        g1 = g1.unsqueeze(0)

        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
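        # The custom op below writes its result into this preallocated buffer
        # in place; it is registered with mutates_args=["core_attn_out"].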
        core_attn_out = torch.zeros(
            (1, num_tokens, self.local_num_heads, self.head_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops.vllm.kda_attention(
            q,
            k,
            v,
            g1,
            beta,
            core_attn_out,
            self.prefix,
        )

        core_attn_out = self.o_norm(core_attn_out, g2)
        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
        output[:] = self.o_proj(core_attn_out)[0]

    def _forward(
        self,
        q_proj_states: torch.Tensor,
        k_proj_states: torch.Tensor,
        v_proj_states: torch.Tensor,
        g1: torch.Tensor,
        beta: torch.Tensor,
        core_attn_out: torch.Tensor,
    ) -> None:
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]

        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        num_actual_tokens = attn_metadata.num_actual_tokens

        constant_caches = self.kv_cache[forward_context.virtual_engine]

        q_proj_states = q_proj_states[:num_actual_tokens]
        k_proj_states = k_proj_states[:num_actual_tokens]
        v_proj_states = v_proj_states[:num_actual_tokens]
        g1 = g1[:num_actual_tokens]
        beta = beta[:num_actual_tokens]

        (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
        # deal with strides
        conv_state_q = conv_state_q.transpose(-1, -2)
        conv_state_k = conv_state_k.transpose(-1, -2)
        conv_state_v = conv_state_v.transpose(-1, -2)

        q_conv_weights = self.q_conv1d.weight.view(
            self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
        )
        k_conv_weights = self.k_conv1d.weight.view(
            self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
        )
        v_conv_weights = self.v_conv1d.weight.view(
            self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
        )

        if attn_metadata.num_prefills > 0:
            q_proj_states = q_proj_states.transpose(0, 1)
            k_proj_states = k_proj_states.transpose(0, 1)
            v_proj_states = v_proj_states.transpose(0, 1)
            q = causal_conv1d_fn(
                q_proj_states,
                q_conv_weights,
                self.q_conv1d.bias,
                activation="silu",
                conv_states=conv_state_q,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
            k = causal_conv1d_fn(
                k_proj_states,
                k_conv_weights,
                self.k_conv1d.bias,
                activation="silu",
                conv_states=conv_state_k,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
            v = causal_conv1d_fn(
                v_proj_states,
                v_conv_weights,
                self.v_conv1d.bias,
                activation="silu",
                conv_states=conv_state_v,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        else:
            decode_conv_indices = non_spec_state_indices_tensor[
                : attn_metadata.num_actual_tokens
            ]
            q = causal_conv1d_update(
                q_proj_states,
                conv_state_q,
                q_conv_weights,
                self.q_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )
            k = causal_conv1d_update(
                k_proj_states,
                conv_state_k,
                k_conv_weights,
                self.k_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )
            v = causal_conv1d_update(
                v_proj_states,
                conv_state_v,
                v_conv_weights,
                self.v_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )
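        # The flattened (num_tokens, num_heads * head_dim) conv outputs are
        # reshaped to (1, num_tokens, num_heads, head_dim) for the KDA kernels.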
        q, k, v = map(
            lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
        )

        if attn_metadata.num_prefills > 0:
            zero_idx = non_spec_state_indices_tensor[~has_initial_state]
            recurrent_state[zero_idx] = 0
            initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = chunk_kda(
                q=q,
                k=k,
                v=v,
                g=g1,
                beta=beta,
                initial_state=initial_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=non_spec_query_start_loc,
            )
            # Write the updated recurrent state back into the cache
            recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
        else:
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = fused_recurrent_kda(
                q=q,
                k=k,
                v=v,
                g=g1,
                beta=beta,
                initial_state=recurrent_state,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
                ssm_state_indices=non_spec_state_indices_tensor,
            )

        core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
            0, :num_actual_tokens
        ]