Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
import torch
from vllm.triton_utils import tl, triton
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Z, # pointer to the other branch
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row: tl.int64,
stride_y_row: tl.int64,
stride_z_row: tl.int64,
M: tl.int64, # number of rows in X
N: tl.int64, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
group = tl.program_id(1)
X += row * stride_x_row + group * N
Y += row * stride_y_row + group * N
if HAS_Z:
Z += row * stride_z_row + group * N
if not IS_RMS_NORM:
Mean += group * M
Rstd += group * M
W += group * N
if HAS_BIAS:
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w + b if HAS_BIAS else x_hat * w
if HAS_Z and NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=mask).to(tl.float32)
y *= z * tl.sigmoid(z)
# Write output
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
x,
weight,
bias,
eps,
z=None,
out=None,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
M, N = x.shape
if group_size is None:
group_size = N
assert N % group_size == 0
ngroups = N // group_size
assert x.stride(-1) == 1
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
def rms_norm_gated(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, _, _ = _layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=True,
)
return y.reshape(x_shape_og)

View File

@@ -0,0 +1,586 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch
from packaging import version
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON, tl, triton
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
if TRITON3:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
return dt
else:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
return dt
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
{
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
is not None
}
)
@triton.heuristics(
{"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None}
)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None})
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit(do_not_specialize=["N"])
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,
x_ptr,
dt_ptr,
dt_bias_ptr,
A_ptr,
B_ptr,
C_ptr,
D_ptr,
z_ptr,
out_ptr,
state_batch_indices_ptr,
dst_state_batch_indices_ptr,
pad_slot_id,
num_accepted_tokens_ptr,
cu_seqlens_ptr,
# Matrix dimensions
N,
nheads,
dim,
dstate,
nheads_ngroups_ratio,
# Strides
stride_state_batch,
stride_state_head,
stride_state_dim,
stride_state_dstate,
stride_x_batch,
stride_x_head,
stride_x_dim,
stride_dt_batch,
stride_dt_head,
stride_dt_dim,
stride_dt_bias_head,
stride_dt_bias_dim,
stride_A_head,
stride_A_dim,
stride_A_dstate,
stride_B_batch,
stride_B_group,
stride_B_dstate,
stride_C_batch,
stride_C_group,
stride_C_dstate,
stride_D_head,
stride_D_dim,
stride_z_batch,
stride_z_head,
stride_z_dim,
stride_out_batch,
stride_out_head,
stride_out_dim,
stride_state_indices_batch,
stride_state_indices_T,
stride_dst_state_indices_batch,
stride_dst_state_indices_T,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
TIE_HDIM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_STATE_BATCH_INDICES: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_VARLEN: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
if IS_VARLEN:
bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64)
eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64)
seq_len = eos - bos
if seq_len == 0:
return
else:
bos = pid_b
seq_len = 1
state_ptr_base = state_ptr
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
if IS_SPEC_DECODING:
num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64)
init_token_idx = tl.maximum(num_accepted - 1, 0)
else:
init_token_idx = 0
dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch
if not IS_SPEC_DECODING:
dst_state_batch_idx = tl.load(
dst_state_batch_indices_ptr
+ init_token_idx * stride_dst_state_indices_T
).to(tl.int64)
dst_state_ptr = state_ptr + (
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
)
state_batch_indices_ptr += (
pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T
)
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
else:
dst_state_ptr = (
state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head
)
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += bos * stride_x_batch + pid_h * stride_x_head
dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head
if HAS_DT_BIAS:
dt_bias_ptr += pid_h * stride_dt_bias_head
A_ptr += pid_h * stride_A_head
B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
if HAS_Z:
z_ptr += bos * stride_z_batch + pid_h * stride_z_head
out_ptr += bos * stride_out_batch + pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
if not IS_SPEC_DECODING:
dst_state_ptrs = dst_state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
if HAS_D:
D_ptr += pid_h * stride_D_head
D_ptrs = D_ptr + offs_m * stride_D_dim
A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
for i_t in range(seq_len):
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(
A_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
if IS_SPEC_DECODING:
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
if token_dst_idx != pad_slot_id:
token_dst_ptrs = (
state_ptr_base
+ token_dst_idx * stride_state_batch
+ pid_h * stride_state_head
+ offs_m[:, None] * stride_state_dim
+ offs_n[None, :] * stride_state_dstate
)
tl.store(
token_dst_ptrs, state.to(token_dst_ptrs.dtype.element_ty), mask=mask
)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
x_ptr += stride_x_batch
dt_ptr += stride_dt_batch
B_ptr += stride_B_batch
C_ptr += stride_C_batch
out_ptr += stride_out_batch
if HAS_Z:
z_ptr += stride_z_batch
if not IS_SPEC_DECODING:
tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
def selective_state_update(
state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None,
num_accepted_tokens=None,
cu_seqlens=None,
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
num_accepted_tokens: (batch,)
number of accepted tokens from previous verification step,
tells the kernel which initial state to use
cu_seqlens: (batch,)
length per sequence, for variable length in speculative decoding cases
"""
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
if num_accepted_tokens is not None:
assert state_batch_indices is not None and state_batch_indices.dim() == 2
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
if state_batch_indices is not None and state_batch_indices.dim() == 1:
state_batch_indices = state_batch_indices.unsqueeze(1)
if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
if cu_seqlens is not None:
N = len(cu_seqlens) - 1
# Only used to verify the shape of
# state_batch_indices and dst_state_batch_indices
max_seqlen = (
state_batch_indices.size(-1) if state_batch_indices is not None else 1
)
else:
N = batch
max_seqlen = 1
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape[0] >= N
assert state_batch_indices.shape[1] >= max_seqlen
if dst_state_batch_indices is not None:
assert dst_state_batch_indices.shape[0] >= N
assert dst_state_batch_indices.shape[1] >= max_seqlen
else:
# revert to the default behavior of in-place state updates
dst_state_batch_indices = state_batch_indices
assert out.shape == x.shape
if num_accepted_tokens is not None:
assert num_accepted_tokens.shape == (N,)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads)
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
state_batch_indices_strides = (
(state_batch_indices.stride(0), state_batch_indices.stride(1))
if state_batch_indices is not None
else (0, 0)
)
dst_state_batch_indices_strides = (
(dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1))
if dst_state_batch_indices is not None
else (0, 0)
)
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
(32, 4)
if dstate <= 16
else (
(16, 4)
if dstate <= 32
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
)
)
tie_hdim = (
A.stride(-1) == 0
and A.stride(-2) == 0
and dt.stride(-1) == 0
and dt_bias.stride(-1) == 0
)
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
state,
x,
dt,
dt_bias,
A,
B,
C,
D,
z,
out,
state_batch_indices,
dst_state_batch_indices,
pad_slot_id,
num_accepted_tokens,
cu_seqlens,
N,
nheads,
dim,
dstate,
nheads // ngroups,
state.stride(0),
state.stride(1),
state.stride(2),
state.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
dt.stride(0),
dt.stride(1),
dt.stride(2),
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
C.stride(0),
C.stride(1),
C.stride(2),
*(D.stride(0), D.stride(1)) if D is not None else 0,
z_strides[0],
z_strides[1],
z_strides[2],
out.stride(0),
out.stride(1),
out.stride(2),
state_batch_indices_strides[0],
state_batch_indices_strides[1],
dst_state_batch_indices_strides[0],
dst_state_batch_indices_strides[1],
dt_softplus,
tie_hdim,
BLOCK_SIZE_M,
num_warps=num_warps,
)
def selective_scan_fn(
u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID,
block_size=1024,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
C: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
dt_bias: (dim,) or (dim)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended with 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
A tensor with each cell is a correspondent
input and output ssm_state indices
- Without APC: (batch,) - single state index per batch item
- With APC: (batch, max_positions) - cache block indices for read/write
Each non-zero value indicates a cache block to load from and/or write to.
has_initial_state: (batch) bool
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
block_size: int
The block size to align the cached states to
block_idx_first_scheduled_token: (batch,), dtype int32
The pointer into cache_indices, where the first
cache block to be filled is located.
block_idx_last_scheduled_token: (batch,), dtype int32
The pointer into cache_indices, where the last cache block
to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into cache_indices, where the cache block
containing the initial state is located.
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
"""
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3 and query_start_loc is None:
B = B.unsqueeze(1)
if B.dim() == 2 and query_start_loc is not None:
B = B.unsqueeze(0)
if C.dim() == 3 and query_start_loc is None:
C = C.unsqueeze(1)
if C.dim() == 2 and query_start_loc is not None:
C = C.unsqueeze(0)
ops.selective_scan_fwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
query_start_loc,
cache_indices,
has_initial_state,
ssm_states,
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
)
if z is None:
return delta # output written inplace to delta
else:
return z # output written inplace to z

View File

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
# ruff: noqa: E501,SIM102
import torch
from vllm.triton_utils import tl, triton
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2,
),
],
key=["chunk_size", "K", "IS_CAUSAL"],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
out_ptr,
cu_chunk_seqlens_ptr,
# Matrix dimensions
seqlen,
chunk_size: tl.constexpr,
K: tl.constexpr,
ngroups: tl.constexpr,
stride_a_seqlen: tl.int64,
stride_a_head: tl.int64,
stride_ak: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_bk: tl.constexpr,
stride_out_chunk: tl.int64,
stride_out_head: tl.int64,
stride_outm: tl.int64,
stride_outn: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_ch = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
if IS_CAUSAL:
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
return
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# compute a * b.T
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(
a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
).to(dot_dtype)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
& (offs_n[None, :] < chunk_size_limit),
other=0.0,
).to(dot_dtype)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
tl.store(
out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
)
def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
"""
Argument:
a: (seqlen, ngroups, k)
b: (seqlen, ngroups, k)
chunk_size: int
cu_chunk_seq_lens: (nchunks+1,)
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
guaranteed to be correct.
Return:
out: (nchunks, ngroups, chunk_size, chunk_size)
"""
seqlen, ngroups, k = a.shape
assert b.shape == a.shape
if a.stride(-1) != 1 and a.stride(0) != 1:
a = a.contiguous()
if b.stride(-1) != 1 and b.stride(0) != 1:
b = b.contiguous()
nchunks = len(cu_chunk_seqlens) - 1
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty(
(nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype
)
dot_dtype = (
tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
else (
tl.float16
if a.dtype == torch.float16 or b.dtype == torch.float16
else tl.float32
)
)
grid = lambda META: (
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
nchunks * ngroups,
)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a_ptr=a,
b_ptr=b,
out_ptr=out,
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
seqlen=seqlen,
chunk_size=chunk_size,
K=k,
ngroups=ngroups,
stride_a_seqlen=a.stride(0),
stride_a_head=a.stride(1),
stride_ak=a.stride(2),
stride_b_seqlen=b.stride(0),
stride_b_head=b.stride(1),
stride_bk=b.stride(2),
stride_out_chunk=out.stride(0),
stride_out_head=out.stride(1),
stride_outm=out.stride(-2),
stride_outn=out.stride(-1),
IS_CAUSAL=causal,
dot_dtype=dot_dtype,
)
return out

View File

@@ -0,0 +1,456 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
# ruff: noqa: E501,SIM102
from packaging import version
from vllm.triton_utils import tl, triton
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2,
),
],
key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
)
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices
cb_ptr,
x_ptr,
z_ptr,
out_ptr,
dt_ptr,
dA_cumsum_ptr,
seq_idx_ptr,
C_ptr,
states_ptr,
D_ptr,
initstates_ptr,
cu_chunk_seqlens_ptr,
# Matrix dimensions
chunk_size: tl.constexpr,
hdim: tl.constexpr,
dstate: tl.constexpr,
seqlen,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_cb_chunk: tl.int64,
stride_cb_head: tl.int64,
stride_cb_csize_m: tl.int64,
stride_cb_csize_k: tl.constexpr,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_z_seqlen: tl.int64,
stride_z_head: tl.int64,
stride_z_hdim: tl.constexpr,
stride_out_seqlen: tl.int64,
stride_out_head: tl.int64,
stride_out_hdim: tl.constexpr,
stride_dt_chunk: tl.int64,
stride_dt_head: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_head: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_seq_idx_chunk: tl.constexpr,
stride_C_seqlen: tl.int64,
stride_C_head: tl.int64,
stride_C_dstate: tl.constexpr,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
stride_D_head: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_Z: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
IS_TRITON_22: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += (
chunk_seqlen_start * stride_C_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_C_head
)
# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
seq_idx_ptr += pid_c * stride_seq_idx_chunk
seq_idx = tl.load(seq_idx_ptr)
seq_idx_prev = tl.load(
seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
)
if HAS_INITSTATES and (seq_idx != seq_idx_prev):
prev_states_ptr = (
initstates_ptr
+ seq_idx * stride_init_states_batch
+ pid_h * stride_init_states_head
)
prev_states_hdim = stride_init_states_hdim
prev_states_dstate = stride_init_states_dstate
else:
prev_states_ptr = (
states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
)
prev_states_hdim = stride_states_hdim
prev_states_dstate = stride_states_dstate
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
).to(tl.float32)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
offs_k_dstate = tl.arange(
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
)
C_ptrs = C_ptr + (
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
)
scale_m = tl.exp(dA_cs_m)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate),
other=0.0,
)
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
# if no init states AND starting a new sequence, we need zeros
prev_states = tl.zeros(
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
)
else:
# otherwise read the previous state
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc = tl.dot(C, prev_states) * scale_m[:, None]
else:
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
for k in range(0, dstate, BLOCK_SIZE_K):
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate - k),
other=0.0,
)
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
prev_states = tl.zeros(
(BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
)
else:
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate - k)
& (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc += tl.dot(C, prev_states)
C_ptrs += BLOCK_SIZE_K
prev_states_ptrs += BLOCK_SIZE_K
acc *= scale_m[:, None]
offs_k = tl.arange(0, BLOCK_SIZE_K)
cb_ptrs = cb_ptr + (
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
)
x_ptrs = x_ptr + (
offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
K_MAX = (
chunk_size_limit
if not IS_CAUSAL
else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
)
for k in range(0, K_MAX, BLOCK_SIZE_K):
cb = tl.load(
cb_ptrs,
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
tl.float32
)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
mask = offs_m[:, None] >= k + offs_k[None, :]
cb = tl.where(mask, cb, 0.0)
cb = cb.to(x_ptr.dtype.element_ty)
x = tl.load(
x_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
other=0.0,
)
acc += tl.dot(cb, x)
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_D:
if D_HAS_HDIM:
D = tl.load(
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
x_residual = tl.load(
x_ptr
+ (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc += x_residual * D
if HAS_Z:
z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (
stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
)
z = tl.load(
z_ptrs,
mask=(offs_out_m[:, None] < chunk_size_limit)
& (offs_out_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc *= z * tl.sigmoid(z)
out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
)
tl.store(
out_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
)
def _chunk_scan_fwd(
cb,
x,
dt,
dA_cumsum,
C,
states,
cu_chunk_seqlens,
out,
seq_idx,
D=None,
z=None,
initial_states=None,
):
assert seq_idx is not None, "this implementation requires seq_idx"
seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = C.shape
assert nheads % ngroups == 0
assert C.shape == (seqlen, ngroups, dstate)
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
if z is not None:
assert z.shape == x.shape
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
assert states.shape == (nchunks, nheads, headdim, dstate)
assert seq_idx.shape == (nchunks,)
grid = lambda META: (
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
nchunks,
nheads,
)
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
initial_states_strides = (
(
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3),
)
if initial_states is not None
else (0, 0, 0, 0)
)
_chunk_scan_fwd_kernel[grid](
cb_ptr=cb,
x_ptr=x,
z_ptr=z,
out_ptr=out,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
seq_idx_ptr=seq_idx,
C_ptr=C,
states_ptr=states,
D_ptr=D,
initstates_ptr=initial_states,
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
chunk_size=chunk_size,
hdim=headdim,
dstate=dstate,
seqlen=seqlen,
nheads_ngroups_ratio=nheads // ngroups,
stride_cb_chunk=cb.stride(0),
stride_cb_head=cb.stride(1),
stride_cb_csize_m=cb.stride(2),
stride_cb_csize_k=cb.stride(3),
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_z_seqlen=z_strides[0],
stride_z_head=z_strides[1],
stride_z_hdim=z_strides[2],
stride_out_seqlen=out.stride(0),
stride_out_head=out.stride(1),
stride_out_hdim=out.stride(2),
stride_dt_chunk=dt.stride(1),
stride_dt_head=dt.stride(0),
stride_dt_csize=dt.stride(2),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_seq_idx_chunk=seq_idx.stride(0),
stride_C_seqlen=C.stride(0),
stride_C_head=C.stride(1),
stride_C_dstate=C.stride(2),
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_init_states_batch=initial_states_strides[0],
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
stride_D_head=D.stride(0) if D is not None else 0,
IS_CAUSAL=True,
HAS_D=D is not None,
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
HAS_Z=z is not None,
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
IS_TRITON_22=TRITON_22,
HAS_INITSTATES=initial_states is not None,
)
return

View File

@@ -0,0 +1,700 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
from .mamba_ssm import softplus
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_H": 2}),
triton.Config({"BLOCK_SIZE_H": 4}),
triton.Config({"BLOCK_SIZE_H": 8}),
triton.Config({"BLOCK_SIZE_H": 16}),
triton.Config({"BLOCK_SIZE_H": 32}),
triton.Config({"BLOCK_SIZE_H": 64}),
],
key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
# Pointers to matrices
dt_ptr,
A_ptr,
dt_bias_ptr,
dt_out_ptr,
dA_cumsum_ptr,
cu_chunk_seqlens_ptr,
# Matrix dimension
seqlen,
nheads: tl.constexpr,
chunk_size: tl.constexpr,
dt_min: tl.constexpr,
dt_max: tl.constexpr,
# Strides
stride_dt_seqlen: tl.int64,
stride_dt_head: tl.constexpr,
stride_A_head: tl.constexpr,
stride_dt_bias_head: tl.constexpr,
stride_dt_out_head: tl.int64,
stride_dt_out_chunk: tl.int64,
stride_dt_out_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
BLOCK_SIZE_H: tl.constexpr,
BLOCK_SIZE_CHUNK: tl.constexpr,
):
# if dt is long, may cause problems, so use 64 bit
# https://github.com/triton-lang/triton/issues/1058
pid_c = tl.program_id(axis=0).to(tl.int64)
pid_h = tl.program_id(axis=1)
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
dt_ptr += chunk_seqlen_start * stride_dt_seqlen
dt_out_ptr += pid_c * stride_dt_out_chunk
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
dt_ptrs = dt_ptr + (
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
)
A_ptrs = A_ptr + offs_h * stride_A_head
dt_out_ptrs = dt_out_ptr + (
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
)
dA_cs_ptrs = dA_cumsum_ptr + (
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
)
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
dt = tl.load(
dt_ptrs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
other=0.0,
).to(tl.float32)
if HAS_DT_BIAS:
dt_bias = tl.load(
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
).to(tl.float32)
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)
dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.where(
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
)
tl.store(
dt_out_ptrs,
dt,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
)
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
dA = dt * A[:, None]
dA_cs = tl.cumsum(dA, axis=1)
tl.store(
dA_cs_ptrs,
dA_cs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
)
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2,
),
],
key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_fwd_kernel(
# Pointers to matrices
x_ptr,
b_ptr,
states_ptr,
dt_ptr,
dA_cumsum_ptr,
cu_chunk_seqlens_ptr,
# Matrix dimensions
hdim: tl.constexpr,
dstate: tl.constexpr,
chunk_size: tl.constexpr,
seqlen,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_b_dstate: tl.constexpr,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_dt_head: tl.int64,
stride_dt_chunk: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
b_ptr += (
chunk_seqlen_start * stride_b_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
)
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
)
b_ptrs = b_ptr + (
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
tl.float32
)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(
x_ptrs,
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
states = acc.to(states_ptr.dtype.element_ty)
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2,
),
],
key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices
x_ptr,
b_ptr,
dt_ptr,
dA_cumsum_ptr,
chunk_states_ptr,
cu_seqlens_ptr,
states_ptr,
initstates_ptr,
# Matrix dimensions
hdim: tl.constexpr,
dstate: tl.constexpr,
chunk_size: tl.constexpr,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_b_dstate: tl.constexpr,
stride_dt_head: tl.int64,
stride_dt_chunk: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_chunk_states_chunk: tl.int64,
stride_chunk_states_head: tl.int64,
stride_chunk_states_hdim: tl.int64,
stride_chunk_states_dstate: tl.constexpr,
stride_states_batch: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
pid_c = (end_idx - 1) // chunk_size
b_ptr += (
pid_c * chunk_size * stride_b_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
)
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
chunk_states_ptr += (
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
)
if HAS_INITSTATES:
# if there are init states provided, we differentiate between states (which
# are boundary conditions at a chunk boundary) and initstates (which are boundary
# conditions when a new example in a cont batch starts)
initstates_ptr += pid_h * stride_init_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
)
b_ptrs = b_ptr + (
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
chunk_size_limit = end_idx - pid_c * chunk_size
start_idx = tl.load(cu_seqlens_ptr + pid_b)
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(
x_ptrs,
mask=(offs_m[:, None] < hdim)
& (offs_k[None, :] < chunk_size_limit - k)
& (offs_k[None, :] >= start_idx_cur - k),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k)
& (offs_n[None, :] < dstate)
& (offs_k[:, None] >= start_idx_cur - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = tl.where(
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
tl.exp(dA_cs_last - dA_cs_k) * dt_k,
0.0,
)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
# If HAS_INITSTATES==True need to consider two possibilities
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
# - if state_idx >= pid * chunk_size, then we need to insert initstates
if (
(start_idx < pid_c * chunk_size) # first chunk
or (HAS_INITSTATES)
):
dA_cs_boundary = 0.0 # default
if not HAS_INITSTATES:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:
# - this seems repetitive, buts its to help the compiler
if start_idx < pid_c * chunk_size:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:
past_states_ptrs = initstates_ptr + (
pid_b * stride_init_states_batch
+ offs_m[:, None] * stride_init_states_hdim
+ offs_n[None, :] * stride_init_states_dstate
)
# need to adjust the boundary
if start_idx > pid_c * chunk_size:
dA_cs_boundary = tl.load(
dA_cumsum_ptr
+ (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)
past_states = tl.load(
past_states_ptrs,
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
scale = tl.exp(dA_cs_last - dA_cs_boundary)
acc += past_states * scale
states = acc.to(states_ptr.dtype.element_ty)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)
def _chunk_cumsum_fwd(
dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
):
seqlen, nheads = dt.shape
assert A.shape == (nheads,)
if dt_bias is not None:
assert dt_bias.shape == (nheads,)
nchunks = cu_chunk_seqlens.shape[0] - 1
dt_out = torch.empty(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
)
dA_cumsum = torch.empty(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
)
grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))
with torch.cuda.device(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt_ptr=dt,
A_ptr=A,
dt_bias_ptr=dt_bias,
dt_out_ptr=dt_out,
dA_cumsum_ptr=dA_cumsum,
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
seqlen=seqlen,
nheads=nheads,
chunk_size=chunk_size,
dt_min=dt_limit[0],
dt_max=dt_limit[1],
stride_dt_seqlen=dt.stride(0),
stride_dt_head=dt.stride(1),
stride_A_head=A.stride(0),
stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0,
stride_dt_out_head=dt_out.stride(0),
stride_dt_out_chunk=dt_out.stride(1),
stride_dt_out_csize=dt_out.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
DT_SOFTPLUS=dt_softplus,
HAS_DT_BIAS=dt_bias is not None,
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
)
return dA_cumsum, dt_out
def _chunk_state_fwd(
B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True
):
seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
if states is not None:
assert states.shape == (nchunks, nheads, headdim, dstate)
else:
states_dtype = torch.float32 if states_in_fp32 else B.dtype
states = torch.empty(
(nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype
)
grid = lambda META: (
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
nchunks,
nheads,
)
with torch.cuda.device(x.device.index):
_chunk_state_fwd_kernel[grid](
x_ptr=x,
b_ptr=B,
states_ptr=states,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
hdim=headdim,
dstate=dstate,
chunk_size=chunk_size,
seqlen=seqlen,
nheads_ngroups_ratio=nheads // ngroups,
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_b_seqlen=B.stride(0),
stride_b_head=B.stride(1),
stride_b_dstate=B.stride(2),
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_dt_head=dt.stride(0),
stride_dt_chunk=dt.stride(1),
stride_dt_csize=dt.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
)
return states
def chunk_state_varlen(
B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
):
total_seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
batch = cu_seqlens.shape[0] - 1
cu_seqlens = cu_seqlens.contiguous()
assert nheads % ngroups == 0
assert B.shape == (total_seqlen, ngroups, dstate)
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
if initial_states is not None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
states = torch.empty(
batch,
nheads,
headdim,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device,
)
initial_states_strides = (
(
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3),
)
if initial_states is not None
else (0, 0, 0, 0)
)
grid = lambda META: (
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
batch,
nheads,
)
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x_ptr=x,
b_ptr=B,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
chunk_states_ptr=chunk_states,
cu_seqlens_ptr=cu_seqlens,
states_ptr=states,
initstates_ptr=initial_states,
hdim=headdim,
dstate=dstate,
chunk_size=chunk_size,
nheads_ngroups_ratio=nheads // ngroups,
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_b_seqlen=B.stride(0),
stride_b_head=B.stride(1),
stride_b_dstate=B.stride(2),
stride_dt_head=dt.stride(0),
stride_dt_chunk=dt.stride(1),
stride_dt_csize=dt.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_chunk_states_chunk=chunk_states.stride(0),
stride_chunk_states_head=chunk_states.stride(1),
stride_chunk_states_hdim=chunk_states.stride(2),
stride_chunk_states_dstate=chunk_states.stride(3),
stride_states_batch=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_init_states_batch=initial_states_strides[0],
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
HAS_INITSTATES=initial_states is not None,
)
return states

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
# ruff: noqa: E501
import torch
from einops import rearrange
from packaging import version
from vllm.triton_utils import triton
from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
from .ssd_state_passing import _state_passing_fwd
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
def is_int_pow_2(n):
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
def _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_seqlens=None,
cu_chunk_seqlens=None,
last_chunk_indices=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
seqlen, nheads, headdim = x.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (seqlen, nheads)
assert A.shape == (nheads,)
assert C.shape == B.shape
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
if seq_idx is not None:
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if (
x.stride(-1) != 1 and x.stride(0) != 1
): # Either M or K dimension should be contiguous
x = x.contiguous()
if (
z is not None and z.stride(-1) != 1 and z.stride(0) != 1
): # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
if initial_states is not None:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
# which has a minimal implementation to understand the below operations
# - as explained by the blog, mamba is a special case of causal attention
# - the idea is to chunk the attention matrix and compute each
# submatrix separately using different optimizations.
# - see the blog and paper for a visualization of the submatrices
# which we refer to in the comments below
# 1. Compute chunked cumsum of A * dt
# - here dt may go through a softplus activation
dA_cumsum, dt = _chunk_cumsum_fwd(
dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
states = _chunk_state_fwd(
B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states and
# ii) seq_idx to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum, # (nheads, nchunks, chunk_size)
cu_chunk_seqlens,
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None
else None, # (batch, nheads, headdim*dstate)
seq_idx=seq_idx,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
)
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)
# 5. Scan and compute the diagonal blocks, taking into
# account past causal states.
# - if initial states are provided, then states information will be
# augmented with initial_states.
# - to do this properly, we need to account for example changes in
# the continuous batch, therefore we introduce pseudo chunks, which is
# a chunk that is split up each time an example changes.
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
_chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
cu_chunk_seqlens,
out, # in-place update
seq_idx,
D=D,
z=z,
initial_states=initial_states,
)
if return_intermediate_states:
return states
else:
return states[last_chunk_indices]
def mamba_chunk_scan_combined_varlen(
x,
dt,
A,
B,
C,
chunk_size,
cu_seqlens,
cu_chunk_seqlens,
last_chunk_indices,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_intermediate_states=False,
state_dtype=None,
):
"""
Argument:
x: (seqlen, nheads, headdim)
dt: (seqlen, nheads)
A: (nheads)
B: (seqlen, ngroups, dstate)
C: (seqlen, ngroups, dstate)
chunk_size: int
cu_seqlens: (batch + 1,)
cu_chunk_seqlens: (nchunks + 1,)
last_chunk_indices: (batch,)
seq_idx: (nchunks,)
out: (seqlen, nheads, headdim) preallocated output tensor
D: (nheads, headdim) or (nheads,)
z: (seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
dt_softplus: Whether to apply softplus to dt
out: (seqlen, nheads, headdim) preallocated output tensor
state_dtype: The data type of the ssm state
Return:
varlen_states: (batch, nheads, headdim, dstate)
"""
assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
assert seq_idx is not None
varlen_states = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
out,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
return_intermediate_states=return_intermediate_states,
seq_idx=seq_idx,
cu_seqlens=cu_seqlens,
cu_chunk_seqlens=cu_chunk_seqlens,
last_chunk_indices=last_chunk_indices,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
state_dtype=state_dtype,
)
return varlen_states

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
],
key=["dim"],
)
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices
states_ptr,
out_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
cu_chunk_seqlens_ptr,
# Matrix dimensions
dim: tl.constexpr,
nchunks,
seqlen,
chunk_size: tl.constexpr,
# Strides
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_dim: tl.constexpr,
stride_out_chunk: tl.int64,
stride_out_head: tl.int64,
stride_out_dim: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_initstates_batch: tl.int64,
stride_initstates_head: tl.int64,
stride_initstates_dim: tl.constexpr,
stride_seq_idx_chunk: tl.constexpr,
# Meta-parameters
HAS_INITSTATES: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_h = tl.program_id(axis=1)
pid_m = tl.program_id(axis=0)
states_ptr += pid_h * stride_states_head
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
if HAS_INITSTATES:
initstates_ptrs = (
initstates_ptr
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
else:
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
prev_seq_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
# we have started a new sequence
if prev_seq_idx != seq_idx:
if HAS_INITSTATES:
initstates_ptrs = (
initstates_ptr
+ seq_idx * stride_initstates_batch
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
tl.float32
)
else:
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
prev_seq_idx = seq_idx
states = tl.exp(dA_cs) * states + new_states
tl.store(out_ptrs, states, mask=offs_m < dim)
states_ptrs += stride_states_chunk
dA_cs_ptr += stride_dA_cs_chunk
out_ptrs += stride_out_chunk
def _state_passing_fwd(
states,
dA_cumsum,
cu_chunk_seqlens,
seq_idx,
initial_states=None,
out_dtype=None,
):
nchunks, nheads, dim = states.shape
chunk_size = dA_cumsum.shape[-1]
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
seqlen = seq_idx.shape[-1]
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)
initial_states_strides = (
(initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
if initial_states is not None
else (0, 0, 0)
)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states_ptr=states,
out_ptr=out,
dA_cs_ptr=dA_cumsum,
initstates_ptr=initial_states,
seq_idx_ptr=seq_idx,
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
dim=dim,
nchunks=nchunks,
seqlen=seqlen if seq_idx is not None else 0,
chunk_size=chunk_size if seq_idx is not None else 0,
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_dim=states.stride(2),
stride_out_chunk=out.stride(0),
stride_out_head=out.stride(1),
stride_out_dim=out.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_initstates_batch=initial_states_strides[0],
stride_initstates_head=initial_states_strides[1],
stride_initstates_dim=initial_states_strides[2],
stride_seq_idx_chunk=seq_idx.stride(0),
HAS_INITSTATES=initial_states is not None,
)
return out