# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2024, Tri Dao, Albert Gu. # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py # ruff: noqa: E501 import torch from einops import rearrange from packaging import version from vllm.triton_utils import triton from .ssd_bmm import _bmm_chunk_fwd from .ssd_chunk_scan import _chunk_scan_fwd from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") def is_int_pow_2(n): return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 def _mamba_chunk_scan_combined_fwd( x, dt, A, B, C, chunk_size, out, D=None, z=None, dt_bias=None, initial_states=None, return_intermediate_states=False, seq_idx=None, cu_seqlens=None, cu_chunk_seqlens=None, last_chunk_indices=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, ): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" seqlen, nheads, headdim = x.shape _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (seqlen, ngroups, dstate) assert dt.shape == (seqlen, nheads) assert A.shape == (nheads,) assert C.shape == B.shape if z is not None: assert z.shape == x.shape if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) if seq_idx is not None: assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() if ( x.stride(-1) != 1 and x.stride(0) != 1 ): # Either M or K dimension should be contiguous x = x.contiguous() if ( z is not None and z.stride(-1) != 1 and z.stride(0) != 1 ): # Either M or K dimension should be contiguous z = z.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens" if initial_states is not None: assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate) # This function executes 5 sub-functions for computing mamba # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ # which has a minimal implementation to understand the below operations # - as explained by the blog, mamba is a special case of causal attention # - the idea is to chunk the attention matrix and compute each # submatrix separately using different optimizations. # - see the blog and paper for a visualization of the submatrices # which we refer to in the comments below # 1. Compute chunked cumsum of A * dt # - here dt may go through a softplus activation dA_cumsum, dt = _chunk_cumsum_fwd( dt, A, chunk_size, cu_chunk_seqlens, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ) # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd( B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True ) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states and # ii) seq_idx to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, # (nheads, nchunks, chunk_size) cu_chunk_seqlens, initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, # (batch, nheads, headdim*dstate) seq_idx=seq_idx, out_dtype=state_dtype if state_dtype is not None else C.dtype, ) states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32) # 5. Scan and compute the diagonal blocks, taking into # account past causal states. # - if initial states are provided, then states information will be # augmented with initial_states. # - to do this properly, we need to account for example changes in # the continuous batch, therefore we introduce pseudo chunks, which is # a chunk that is split up each time an example changes. # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had # a seq_idx change, in which case we take states information from # init_states. _chunk_scan_fwd( CB, x, dt, dA_cumsum, C, states, cu_chunk_seqlens, out, # in-place update seq_idx, D=D, z=z, initial_states=initial_states, ) if return_intermediate_states: return states else: return states[last_chunk_indices] def mamba_chunk_scan_combined_varlen( x, dt, A, B, C, chunk_size, cu_seqlens, cu_chunk_seqlens, last_chunk_indices, seq_idx, out, D=None, z=None, dt_bias=None, initial_states=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_intermediate_states=False, state_dtype=None, ): """ Argument: x: (seqlen, nheads, headdim) dt: (seqlen, nheads) A: (nheads) B: (seqlen, ngroups, dstate) C: (seqlen, ngroups, dstate) chunk_size: int cu_seqlens: (batch + 1,) cu_chunk_seqlens: (nchunks + 1,) last_chunk_indices: (batch,) seq_idx: (nchunks,) out: (seqlen, nheads, headdim) preallocated output tensor D: (nheads, headdim) or (nheads,) z: (seqlen, nheads, headdim) dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) dt_softplus: Whether to apply softplus to dt out: (seqlen, nheads, headdim) preallocated output tensor state_dtype: The data type of the ssm state Return: varlen_states: (batch, nheads, headdim, dstate) """ assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input" assert seq_idx is not None varlen_states = _mamba_chunk_scan_combined_fwd( x, dt, A, B, C, chunk_size, out, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, return_intermediate_states=return_intermediate_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, last_chunk_indices=last_chunk_indices, dt_softplus=dt_softplus, dt_limit=dt_limit, state_dtype=state_dtype, ) return varlen_states