231 lines
7.2 KiB
Python
231 lines
7.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
|
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
|
|
|
|
# ruff: noqa: E501
|
|
|
|
import torch
|
|
from einops import rearrange
|
|
from packaging import version
|
|
|
|
from vllm.triton_utils import triton
|
|
|
|
from .ssd_bmm import _bmm_chunk_fwd
|
|
from .ssd_chunk_scan import _chunk_scan_fwd
|
|
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
|
|
from .ssd_state_passing import _state_passing_fwd
|
|
|
|
# True when the installed Triton version is 2.2.0 or newer.
TRITON_22 = version.parse("2.2.0") <= version.parse(triton.__version__)
|
|
|
|
|
|
def is_int_pow_2(n):
    """Return True iff ``n`` is a positive ``int`` that is an exact power of two.

    Non-int inputs (e.g. ``4.0`` or ``"4"``), zero, and negatives yield False.
    """
    if not isinstance(n, int) or n <= 0:
        return False
    # A power of two has exactly one set bit, so clearing the lowest set bit
    # (n & (n - 1)) leaves zero.
    return (n & (n - 1)) == 0
|
|
|
|
|
|
def _mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    return_intermediate_states=False,
    seq_idx=None,
    cu_seqlens=None,
    cu_chunk_seqlens=None,
    last_chunk_indices=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    state_dtype=None,
):
    """Run the chunked Mamba-2 (SSD) forward pass over a packed varlen batch.

    The computation is split into five kernel launches (chunked cumsum of
    A * dt, intra-chunk states, inter-chunk state passing, C^T B batched
    matmul, and the final chunk scan); see the numbered step comments below.
    The scan output is written into ``out`` in place.

    Args:
        x: (seqlen, nheads, headdim).
        dt: (seqlen, nheads).
        A: (nheads,).
        B: (seqlen, ngroups, dstate); ``nheads`` must be divisible by ``ngroups``.
        C: same shape as ``B``.
        chunk_size: positive int power of 2.
        out: preallocated output tensor, updated in place by ``_chunk_scan_fwd``.
        D: optional, shape (nheads, headdim) or (nheads,); forwarded to the scan.
        z: optional, same shape as ``x``; forwarded to the scan.
        dt_bias: optional; forwarded to ``_chunk_cumsum_fwd``.
        initial_states: optional, shape (len(cu_seqlens) - 1, nheads, headdim, dstate).
        return_intermediate_states: if True, return the full state tensor from
            the state-passing step; otherwise return
            ``states[last_chunk_indices]``.
        seq_idx: optional, shape (len(cu_chunk_seqlens) - 1,); per-chunk
            sequence index used to switch to the matching initial state when
            a new sequence starts.
        cu_seqlens: required; only its length is used here (to validate
            ``initial_states``).
        cu_chunk_seqlens: chunk boundaries passed to every kernel; required
            whenever ``seq_idx`` is given.
        last_chunk_indices: required unless ``return_intermediate_states``.
        dt_softplus: forwarded to ``_chunk_cumsum_fwd``.
        dt_limit: forwarded to ``_chunk_cumsum_fwd``.
        state_dtype: dtype of the passed states; defaults to ``C.dtype``.

    Returns:
        The state tensor produced at chunk boundaries (step 3), reshaped so
        its last two dims are (headdim, dstate); either all of it or only the
        entries selected by ``last_chunk_indices``.

    Raises:
        AssertionError: on any shape/argument validation failure. All checks
            run before any kernel launches, so ``out`` is untouched on failure.
    """
    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
    seqlen, nheads, headdim = x.shape
    _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate)
    assert dt.shape == (seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if seq_idx is not None:
        # Give a clear message instead of an AttributeError on None.shape.
        assert cu_chunk_seqlens is not None, "seq_idx requires cu_chunk_seqlens"
        assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (
        x.stride(-1) != 1 and x.stride(0) != 1
    ):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (
        z is not None and z.stride(-1) != 1 and z.stride(0) != 1
    ):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
    # Check this before any kernel runs: ``out`` is mutated in place in step 5,
    # and indexing with ``states[None]`` would silently insert a new leading
    # axis instead of selecting the per-sequence final states.
    assert return_intermediate_states or last_chunk_indices is not None, (
        "last_chunk_indices must be provided unless return_intermediate_states=True"
    )

    if initial_states is not None:
        assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)

    # This function executes 5 sub-functions for computing mamba
    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
    #   which has a minimal implementation to understand the below operations
    # - as explained by the blog, mamba is a special case of causal attention
    # - the idea is to chunk the attention matrix and compute each
    #   submatrix separately using different optimizations.
    # - see the blog and paper for a visualization of the submatrices
    #   which we refer to in the comments below

    # 1. Compute chunked cumsum of A * dt
    # - here dt may go through a softplus activation
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt,
        A,
        chunk_size,
        cu_chunk_seqlens,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    states = _chunk_state_fwd(
        B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
    )

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    # - for handling chunked prefill, this requires i) initial_states and
    #   ii) seq_idx to be all specified.
    # - When a new seq_idx is detected, we will stop passing the prev_state
    #   and switch accordingly to the init_state corresponding to the new seq_idx.
    states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum,  # (nheads, nchunks, chunk_size)
        cu_chunk_seqlens,
        initial_states=rearrange(initial_states, "... p n -> ... (p n)")
        if initial_states is not None
        else None,  # (batch, nheads, headdim*dstate)
        seq_idx=seq_idx,
        out_dtype=state_dtype if state_dtype is not None else C.dtype,
    )
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)

    # 4. Compute batched matrix multiply for C_j^T B_i terms
    CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)

    # 5. Scan and compute the diagonal blocks, taking into
    #    account past causal states.
    # - if initial states are provided, then states information will be
    #   augmented with initial_states.
    # - to do this properly, we need to account for example changes in
    #   the continuous batch, therefore we introduce pseudo chunks, which is
    #   a chunk that is split up each time an example changes.
    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
    #   a seq_idx change, in which case we take states information from
    #   init_states.
    _chunk_scan_fwd(
        CB,
        x,
        dt,
        dA_cumsum,
        C,
        states,
        cu_chunk_seqlens,
        out,  # in-place update
        seq_idx,
        D=D,
        z=z,
        initial_states=initial_states,
    )

    if return_intermediate_states:
        return states
    return states[last_chunk_indices]
|
|
|
|
|
|
def mamba_chunk_scan_combined_varlen(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    cu_seqlens,
    cu_chunk_seqlens,
    last_chunk_indices,
    seq_idx,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_intermediate_states=False,
    state_dtype=None,
):
    """Run the chunked Mamba-2 scan over a packed variable-length batch.

    Thin wrapper around ``_mamba_chunk_scan_combined_fwd`` that requires the
    varlen metadata (``cu_seqlens`` and ``seq_idx``). The scan output is
    written into ``out`` in place; the SSM states are returned.

    Argument:
        x: (seqlen, nheads, headdim)
        dt: (seqlen, nheads)
        A: (nheads)
        B: (seqlen, ngroups, dstate)
        C: (seqlen, ngroups, dstate)
        chunk_size: int, must be a power of 2
        cu_seqlens: (batch + 1,)
        cu_chunk_seqlens: (nchunks + 1,)
        last_chunk_indices: (batch,)
        seq_idx: (nchunks,)
        out: (seqlen, nheads, headdim) preallocated output tensor, written in place
        D: (nheads, headdim) or (nheads,)
        z: (seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        dt_softplus: Whether to apply softplus to dt
        dt_limit: forwarded to the dt cumsum kernel; default (0.0, inf)
            (presumably a (min, max) clamp range for dt - confirm in kernel)
        return_intermediate_states: if True, return the states at every chunk
            boundary instead of only the entries at ``last_chunk_indices``
        state_dtype: The data type of the ssm state; defaults to C.dtype
    Return:
        varlen_states: (batch, nheads, headdim, dstate) when
            return_intermediate_states is False; otherwise the full state
            tensor at chunk boundaries (presumably (nchunks, nheads, headdim,
            dstate) - determined by the state-passing kernel)
    """

    assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
    assert seq_idx is not None

    varlen_states = _mamba_chunk_scan_combined_fwd(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        out,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        return_intermediate_states=return_intermediate_states,
        seq_idx=seq_idx,
        cu_seqlens=cu_seqlens,
        cu_chunk_seqlens=cu_chunk_seqlens,
        last_chunk_indices=last_chunk_indices,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        state_dtype=state_dtype,
    )

    return varlen_states
|