v1.0
This commit is contained in:
0
model_executor/layers/mamba/ops/__init__.py
Normal file
0
model_executor/layers/mamba/ops/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1240
model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
1240
model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
File diff suppressed because it is too large
Load Diff
172
model_executor/layers/mamba/ops/layernorm_gated.py
Normal file
172
model_executor/layers/mamba/ops/layernorm_gated.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row: tl.int64,
|
||||
stride_y_row: tl.int64,
|
||||
stride_z_row: tl.int64,
|
||||
M: tl.int64, # number of rows in X
|
||||
N: tl.int64, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (
|
||||
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
def rms_norm_gated(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, _, _ = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=True,
|
||||
)
|
||||
|
||||
return y.reshape(x_shape_og)
|
||||
478
model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
478
model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
@@ -0,0 +1,478 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
|
||||
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
|
||||
|
||||
if TRITON3:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
||||
return dt
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
||||
return dt
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
|
||||
is not None
|
||||
}
|
||||
)
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
|
||||
)
|
||||
@triton.jit
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
x_ptr,
|
||||
dt_ptr,
|
||||
dt_bias_ptr,
|
||||
A_ptr,
|
||||
B_ptr,
|
||||
C_ptr,
|
||||
D_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
dst_state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_state_batch,
|
||||
stride_state_head,
|
||||
stride_state_dim,
|
||||
stride_state_dstate,
|
||||
stride_x_batch,
|
||||
stride_x_head,
|
||||
stride_x_dim,
|
||||
stride_dt_batch,
|
||||
stride_dt_head,
|
||||
stride_dt_dim,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_bias_dim,
|
||||
stride_A_head,
|
||||
stride_A_dim,
|
||||
stride_A_dstate,
|
||||
stride_B_batch,
|
||||
stride_B_group,
|
||||
stride_B_dstate,
|
||||
stride_C_batch,
|
||||
stride_C_group,
|
||||
stride_C_dstate,
|
||||
stride_D_head,
|
||||
stride_D_dim,
|
||||
stride_z_batch,
|
||||
stride_z_head,
|
||||
stride_z_dim,
|
||||
stride_out_batch,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
|
||||
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
|
||||
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
dst_state_batch_indices_ptr += pid_b
|
||||
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
|
||||
dst_state_ptr = state_ptr + (
|
||||
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
)
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
else:
|
||||
dst_state_ptr = (
|
||||
state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
)
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (
|
||||
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
dst_state_ptrs = dst_state_ptr + (
|
||||
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (
|
||||
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
|
||||
)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0)
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(
|
||||
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
|
||||
).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
|
||||
state = state * dA + dB * x[:, None]
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
tl.store(dst_state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
|
||||
def selective_state_update(
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
dst_state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||
In-place updated.
|
||||
"""
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
if out.dim() == 2:
|
||||
out = out.unsqueeze(1)
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch,)
|
||||
if dst_state_batch_indices is not None:
|
||||
assert dst_state_batch_indices.shape == (batch,)
|
||||
else:
|
||||
# revert to the default behavior of in-place state updates
|
||||
dst_state_batch_indices = state_batch_indices
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
|
||||
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = (
|
||||
(32, 4)
|
||||
if dstate <= 16
|
||||
else (
|
||||
(16, 4)
|
||||
if dstate <= 32
|
||||
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
|
||||
)
|
||||
)
|
||||
tie_hdim = (
|
||||
A.stride(-1) == 0
|
||||
and A.stride(-2) == 0
|
||||
and dt.stride(-1) == 0
|
||||
and dt_bias.stride(-1) == 0
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
dt_bias,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
dst_state_batch_indices,
|
||||
pad_slot_id,
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads // ngroups,
|
||||
state.stride(0),
|
||||
state.stride(1),
|
||||
state.stride(2),
|
||||
state.stride(3),
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
||||
z_strides[0],
|
||||
z_strides[1],
|
||||
z_strides[2],
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
dt_softplus,
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
||||
def selective_scan_fn(
|
||||
u,
|
||||
ssm_states,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
block_size=1024,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
applies changes in place.
|
||||
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
applies changes in place.
|
||||
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
A: (dim, dstate)
|
||||
B: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
C: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
D: (dim,)
|
||||
z: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
dt_bias: (dim,) or (dim)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended with 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
A tensor with each cell is a correspondent
|
||||
input and output ssm_state indices
|
||||
- Without APC: (batch,) - single state index per batch item
|
||||
- With APC: (batch, max_positions) - cache block indices for read/write
|
||||
Each non-zero value indicates a cache block to load from and/or write to.
|
||||
has_initial_state: (batch) bool
|
||||
A tensor populated with ones and zeros,
|
||||
indicate if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
there's no initial state
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at indices 0 and 3
|
||||
block_size: int
|
||||
The block size to align the cached states to
|
||||
block_idx_first_scheduled_token: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the first
|
||||
cache block to be filled is located.
|
||||
block_idx_last_scheduled_token: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the last cache block
|
||||
to be filled is located.
|
||||
initial_state_idx: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the cache block
|
||||
containing the initial state is located.
|
||||
returns
|
||||
output: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
supports inplace replacement
|
||||
"""
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3 and query_start_loc is None:
|
||||
B = B.unsqueeze(1)
|
||||
if B.dim() == 2 and query_start_loc is not None:
|
||||
B = B.unsqueeze(0)
|
||||
if C.dim() == 3 and query_start_loc is None:
|
||||
C = C.unsqueeze(1)
|
||||
if C.dim() == 2 and query_start_loc is not None:
|
||||
C = C.unsqueeze(0)
|
||||
|
||||
ops.selective_scan_fwd(
|
||||
u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
delta_softplus,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
)
|
||||
|
||||
if z is None:
|
||||
return delta # output written inplace to delta
|
||||
else:
|
||||
return z # output written inplace to z
|
||||
211
model_executor/layers/mamba/ops/ssd_bmm.py
Normal file
211
model_executor/layers/mamba/ops/ssd_bmm.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["chunk_size", "K", "IS_CAUSAL"],
|
||||
)
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
ngroups: tl.constexpr,
|
||||
stride_a_seqlen: tl.int64,
|
||||
stride_a_head: tl.int64,
|
||||
stride_ak: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_bk: tl.constexpr,
|
||||
stride_out_chunk: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_outm: tl.int64,
|
||||
stride_outn: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
dot_dtype: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_ch = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_c = pid_ch // ngroups
|
||||
pid_h = pid_ch - pid_c * ngroups
|
||||
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
if IS_CAUSAL:
|
||||
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
||||
return
|
||||
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
|
||||
a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
|
||||
b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# compute a * b.T
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
|
||||
& (offs_n[None, :] < chunk_size_limit),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
out,
|
||||
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
|
||||
)
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (seqlen, ngroups, k)
|
||||
b: (seqlen, ngroups, k)
|
||||
chunk_size: int
|
||||
cu_chunk_seq_lens: (nchunks+1,)
|
||||
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
||||
guaranteed to be correct.
|
||||
Return:
|
||||
out: (nchunks, ngroups, chunk_size, chunk_size)
|
||||
"""
|
||||
seqlen, ngroups, k = a.shape
|
||||
assert b.shape == a.shape
|
||||
if a.stride(-1) != 1 and a.stride(0) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(-1) != 1 and b.stride(0) != 1:
|
||||
b = b.contiguous()
|
||||
|
||||
nchunks = len(cu_chunk_seqlens) - 1
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty(
|
||||
(nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype
|
||||
)
|
||||
dot_dtype = (
|
||||
tl.bfloat16
|
||||
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
|
||||
else (
|
||||
tl.float16
|
||||
if a.dtype == torch.float16 or b.dtype == torch.float16
|
||||
else tl.float32
|
||||
)
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
|
||||
nchunks * ngroups,
|
||||
)
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_fwd_kernel[grid](
|
||||
a_ptr=a,
|
||||
b_ptr=b,
|
||||
out_ptr=out,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
seqlen=seqlen,
|
||||
chunk_size=chunk_size,
|
||||
K=k,
|
||||
ngroups=ngroups,
|
||||
stride_a_seqlen=a.stride(0),
|
||||
stride_a_head=a.stride(1),
|
||||
stride_ak=a.stride(2),
|
||||
stride_b_seqlen=b.stride(0),
|
||||
stride_b_head=b.stride(1),
|
||||
stride_bk=b.stride(2),
|
||||
stride_out_chunk=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_outm=out.stride(-2),
|
||||
stride_outn=out.stride(-1),
|
||||
IS_CAUSAL=causal,
|
||||
dot_dtype=dot_dtype,
|
||||
)
|
||||
return out
|
||||
456
model_executor/layers/mamba/ops/ssd_chunk_scan.py
Normal file
456
model_executor/layers/mamba/ops/ssd_chunk_scan.py
Normal file
@@ -0,0 +1,456 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
cb_ptr,
|
||||
x_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
C_ptr,
|
||||
states_ptr,
|
||||
D_ptr,
|
||||
initstates_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
chunk_size: tl.constexpr,
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_cb_chunk: tl.int64,
|
||||
stride_cb_head: tl.int64,
|
||||
stride_cb_csize_m: tl.int64,
|
||||
stride_cb_csize_k: tl.constexpr,
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_z_seqlen: tl.int64,
|
||||
stride_z_head: tl.int64,
|
||||
stride_z_hdim: tl.constexpr,
|
||||
stride_out_seqlen: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_out_hdim: tl.constexpr,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_chunk: tl.constexpr,
|
||||
stride_C_seqlen: tl.int64,
|
||||
stride_C_head: tl.int64,
|
||||
stride_C_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_init_states_batch: tl.int64,
|
||||
stride_init_states_head: tl.int64,
|
||||
stride_init_states_hdim: tl.int64,
|
||||
stride_init_states_dstate: tl.constexpr,
|
||||
stride_D_head: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
D_HAS_HDIM: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
IS_TRITON_22: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
C_ptr += (
|
||||
chunk_seqlen_start * stride_C_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_C_head
|
||||
)
|
||||
|
||||
# M-block offsets and prev states
|
||||
# - logic in next block may override these if there is an active offset
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
|
||||
seq_idx_ptr += pid_c * stride_seq_idx_chunk
|
||||
seq_idx = tl.load(seq_idx_ptr)
|
||||
seq_idx_prev = tl.load(
|
||||
seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
|
||||
)
|
||||
|
||||
if HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
prev_states_ptr = (
|
||||
initstates_ptr
|
||||
+ seq_idx * stride_init_states_batch
|
||||
+ pid_h * stride_init_states_head
|
||||
)
|
||||
prev_states_hdim = stride_init_states_hdim
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
else:
|
||||
prev_states_ptr = (
|
||||
states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
|
||||
)
|
||||
prev_states_hdim = stride_states_hdim
|
||||
prev_states_dstate = stride_states_dstate
|
||||
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dA_cs_m = tl.load(
|
||||
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
|
||||
).to(tl.float32)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k_dstate = tl.arange(
|
||||
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
||||
)
|
||||
C_ptrs = C_ptr + (
|
||||
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
|
||||
)
|
||||
|
||||
scale_m = tl.exp(dA_cs_m)
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(
|
||||
C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k_dstate[None, :] < dstate),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
# if no init states AND starting a new sequence, we need zeros
|
||||
prev_states = tl.zeros(
|
||||
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
|
||||
)
|
||||
else:
|
||||
# otherwise read the previous state
|
||||
prev_states_ptrs = (
|
||||
prev_states_ptr
|
||||
+ offs_n[None, :] * prev_states_hdim
|
||||
+ offs_k_dstate[:, None] * prev_states_dstate
|
||||
)
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
|
||||
acc = tl.dot(C, prev_states) * scale_m[:, None]
|
||||
|
||||
else:
|
||||
prev_states_ptrs = (
|
||||
prev_states_ptr
|
||||
+ offs_n[None, :] * prev_states_hdim
|
||||
+ offs_k_dstate[:, None] * prev_states_dstate
|
||||
)
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
C = tl.load(
|
||||
C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k_dstate[None, :] < dstate - k),
|
||||
other=0.0,
|
||||
)
|
||||
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
prev_states = tl.zeros(
|
||||
(BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
|
||||
)
|
||||
else:
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate - k)
|
||||
& (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc += tl.dot(C, prev_states)
|
||||
C_ptrs += BLOCK_SIZE_K
|
||||
prev_states_ptrs += BLOCK_SIZE_K
|
||||
acc *= scale_m[:, None]
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
cb_ptrs = cb_ptr + (
|
||||
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
|
||||
)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
K_MAX = (
|
||||
chunk_size_limit
|
||||
if not IS_CAUSAL
|
||||
else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
|
||||
)
|
||||
for k in range(0, K_MAX, BLOCK_SIZE_K):
|
||||
cb = tl.load(
|
||||
cb_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
if IS_CAUSAL:
|
||||
mask = offs_m[:, None] >= k + offs_k[None, :]
|
||||
cb = tl.where(mask, cb, 0.0)
|
||||
cb = cb.to(x_ptr.dtype.element_ty)
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
acc += tl.dot(cb, x)
|
||||
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
if HAS_D:
|
||||
if D_HAS_HDIM:
|
||||
D = tl.load(
|
||||
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
|
||||
).to(tl.float32)
|
||||
else:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
||||
x_residual = tl.load(
|
||||
x_ptr
|
||||
+ (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
|
||||
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
acc += x_residual * D
|
||||
|
||||
if HAS_Z:
|
||||
z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
|
||||
z_ptrs = z_ptr + (
|
||||
stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
|
||||
)
|
||||
z = tl.load(
|
||||
z_ptrs,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit)
|
||||
& (offs_out_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
acc *= z * tl.sigmoid(z)
|
||||
|
||||
out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (
|
||||
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
|
||||
)
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
acc,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
|
||||
)
|
||||
|
||||
|
||||
def _chunk_scan_fwd(
|
||||
cb,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
cu_chunk_seqlens,
|
||||
out,
|
||||
seq_idx,
|
||||
D=None,
|
||||
z=None,
|
||||
initial_states=None,
|
||||
):
|
||||
assert seq_idx is not None, "this implementation requires seq_idx"
|
||||
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = C.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert C.shape == (seqlen, ngroups, dstate)
|
||||
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
assert seq_idx.shape == (nchunks,)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
||||
nchunks,
|
||||
nheads,
|
||||
)
|
||||
|
||||
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
|
||||
initial_states_strides = (
|
||||
(
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3),
|
||||
)
|
||||
if initial_states is not None
|
||||
else (0, 0, 0, 0)
|
||||
)
|
||||
|
||||
_chunk_scan_fwd_kernel[grid](
|
||||
cb_ptr=cb,
|
||||
x_ptr=x,
|
||||
z_ptr=z,
|
||||
out_ptr=out,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
seq_idx_ptr=seq_idx,
|
||||
C_ptr=C,
|
||||
states_ptr=states,
|
||||
D_ptr=D,
|
||||
initstates_ptr=initial_states,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
chunk_size=chunk_size,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
seqlen=seqlen,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_cb_chunk=cb.stride(0),
|
||||
stride_cb_head=cb.stride(1),
|
||||
stride_cb_csize_m=cb.stride(2),
|
||||
stride_cb_csize_k=cb.stride(3),
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_z_seqlen=z_strides[0],
|
||||
stride_z_head=z_strides[1],
|
||||
stride_z_hdim=z_strides[2],
|
||||
stride_out_seqlen=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_out_hdim=out.stride(2),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_seq_idx_chunk=seq_idx.stride(0),
|
||||
stride_C_seqlen=C.stride(0),
|
||||
stride_C_head=C.stride(1),
|
||||
stride_C_dstate=C.stride(2),
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_init_states_batch=initial_states_strides[0],
|
||||
stride_init_states_head=initial_states_strides[1],
|
||||
stride_init_states_hdim=initial_states_strides[2],
|
||||
stride_init_states_dstate=initial_states_strides[3],
|
||||
stride_D_head=D.stride(0) if D is not None else 0,
|
||||
IS_CAUSAL=True,
|
||||
HAS_D=D is not None,
|
||||
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
|
||||
HAS_Z=z is not None,
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
IS_TRITON_22=TRITON_22,
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return
|
||||
700
model_executor/layers/mamba/ops/ssd_chunk_state.py
Normal file
700
model_executor/layers/mamba/ops/ssd_chunk_state.py
Normal file
@@ -0,0 +1,700 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .mamba_ssm import softplus
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_SIZE_H": 2}),
|
||||
triton.Config({"BLOCK_SIZE_H": 4}),
|
||||
triton.Config({"BLOCK_SIZE_H": 8}),
|
||||
triton.Config({"BLOCK_SIZE_H": 16}),
|
||||
triton.Config({"BLOCK_SIZE_H": 32}),
|
||||
triton.Config({"BLOCK_SIZE_H": 64}),
|
||||
],
|
||||
key=["chunk_size", "nheads"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
dt_ptr,
|
||||
A_ptr,
|
||||
dt_bias_ptr,
|
||||
dt_out_ptr,
|
||||
dA_cumsum_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimension
|
||||
seqlen,
|
||||
nheads: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
dt_min: tl.constexpr,
|
||||
dt_max: tl.constexpr,
|
||||
# Strides
|
||||
stride_dt_seqlen: tl.int64,
|
||||
stride_dt_head: tl.constexpr,
|
||||
stride_A_head: tl.constexpr,
|
||||
stride_dt_bias_head: tl.constexpr,
|
||||
stride_dt_out_head: tl.int64,
|
||||
stride_dt_out_chunk: tl.int64,
|
||||
stride_dt_out_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
BLOCK_SIZE_H: tl.constexpr,
|
||||
BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||
):
|
||||
# if dt is long, may cause problems, so use 64 bit
|
||||
# https://github.com/triton-lang/triton/issues/1058
|
||||
pid_c = tl.program_id(axis=0).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=1)
|
||||
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
|
||||
dt_ptr += chunk_seqlen_start * stride_dt_seqlen
|
||||
dt_out_ptr += pid_c * stride_dt_out_chunk
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
|
||||
|
||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||
dt_ptrs = dt_ptr + (
|
||||
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
||||
)
|
||||
A_ptrs = A_ptr + offs_h * stride_A_head
|
||||
dt_out_ptrs = dt_out_ptr + (
|
||||
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
||||
)
|
||||
dA_cs_ptrs = dA_cumsum_ptr + (
|
||||
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
||||
)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
dt = tl.load(
|
||||
dt_ptrs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias = tl.load(
|
||||
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
||||
).to(tl.float32)
|
||||
dt += dt_bias[:, None]
|
||||
if DT_SOFTPLUS:
|
||||
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
||||
|
||||
dt = tl.clamp(dt, dt_min, dt_max)
|
||||
dt = tl.where(
|
||||
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
||||
)
|
||||
tl.store(
|
||||
dt_out_ptrs,
|
||||
dt,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
||||
)
|
||||
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dA = dt * A[:, None]
|
||||
dA_cs = tl.cumsum(dA, axis=1)
|
||||
tl.store(
|
||||
dA_cs_ptrs,
|
||||
dA_cs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
||||
)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["hdim", "dstate", "chunk_size"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
states_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_b_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
b_ptr += (
|
||||
chunk_seqlen_start * stride_b_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
)
|
||||
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
||||
tl.float32
|
||||
)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(
|
||||
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
||||
).to(tl.float32)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (
|
||||
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
||||
)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["hdim", "dstate", "chunk_size"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_varlen_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
chunk_states_ptr,
|
||||
cu_seqlens_ptr,
|
||||
states_ptr,
|
||||
initstates_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_b_dstate: tl.constexpr,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_chunk_states_chunk: tl.int64,
|
||||
stride_chunk_states_head: tl.int64,
|
||||
stride_chunk_states_hdim: tl.int64,
|
||||
stride_chunk_states_dstate: tl.constexpr,
|
||||
stride_states_batch: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_init_states_batch: tl.int64,
|
||||
stride_init_states_head: tl.int64,
|
||||
stride_init_states_hdim: tl.int64,
|
||||
stride_init_states_dstate: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
||||
pid_c = (end_idx - 1) // chunk_size
|
||||
b_ptr += (
|
||||
pid_c * chunk_size * stride_b_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
)
|
||||
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
chunk_states_ptr += (
|
||||
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
||||
)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states provided, we differentiate between states (which
|
||||
# are boundary conditions at a chunk boundary) and initstates (which are boundary
|
||||
# conditions when a new example in a cont batch starts)
|
||||
initstates_ptr += pid_h * stride_init_states_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(
|
||||
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
||||
).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
chunk_size_limit = end_idx - pid_c * chunk_size
|
||||
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
||||
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim)
|
||||
& (offs_k[None, :] < chunk_size_limit - k)
|
||||
& (offs_k[None, :] >= start_idx_cur - k),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k)
|
||||
& (offs_n[None, :] < dstate)
|
||||
& (offs_k[:, None] >= start_idx_cur - k),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(
|
||||
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
||||
).to(tl.float32)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
scale = tl.where(
|
||||
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
||||
tl.exp(dA_cs_last - dA_cs_k) * dt_k,
|
||||
0.0,
|
||||
)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
||||
# If HAS_INITSTATES==True need to consider two possibilities
|
||||
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
|
||||
# - if state_idx >= pid * chunk_size, then we need to insert initstates
|
||||
if (
|
||||
(start_idx < pid_c * chunk_size) # first chunk
|
||||
or (HAS_INITSTATES)
|
||||
):
|
||||
dA_cs_boundary = 0.0 # default
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim
|
||||
+ offs_n[None, :] * stride_chunk_states_dstate
|
||||
)
|
||||
else:
|
||||
# - this seems repetitive, buts its to help the compiler
|
||||
if start_idx < pid_c * chunk_size:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim
|
||||
+ offs_n[None, :] * stride_chunk_states_dstate
|
||||
)
|
||||
else:
|
||||
past_states_ptrs = initstates_ptr + (
|
||||
pid_b * stride_init_states_batch
|
||||
+ offs_m[:, None] * stride_init_states_hdim
|
||||
+ offs_n[None, :] * stride_init_states_dstate
|
||||
)
|
||||
|
||||
# need to adjust the boundary
|
||||
if start_idx > pid_c * chunk_size:
|
||||
dA_cs_boundary = tl.load(
|
||||
dA_cumsum_ptr
|
||||
+ (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
||||
).to(tl.float32)
|
||||
|
||||
past_states = tl.load(
|
||||
past_states_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
scale = tl.exp(dA_cs_last - dA_cs_boundary)
|
||||
acc += past_states * scale
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (
|
||||
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
||||
)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
def _chunk_cumsum_fwd(
|
||||
dt,
|
||||
A,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
):
|
||||
seqlen, nheads = dt.shape
|
||||
assert A.shape == (nheads,)
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads,)
|
||||
nchunks = cu_chunk_seqlens.shape[0] - 1
|
||||
dt_out = torch.empty(
|
||||
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
||||
)
|
||||
dA_cumsum = torch.empty(
|
||||
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
||||
)
|
||||
grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))
|
||||
with torch.cuda.device(dt.device.index):
|
||||
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
||||
dt_ptr=dt,
|
||||
A_ptr=A,
|
||||
dt_bias_ptr=dt_bias,
|
||||
dt_out_ptr=dt_out,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
seqlen=seqlen,
|
||||
nheads=nheads,
|
||||
chunk_size=chunk_size,
|
||||
dt_min=dt_limit[0],
|
||||
dt_max=dt_limit[1],
|
||||
stride_dt_seqlen=dt.stride(0),
|
||||
stride_dt_head=dt.stride(1),
|
||||
stride_A_head=A.stride(0),
|
||||
stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0,
|
||||
stride_dt_out_head=dt_out.stride(0),
|
||||
stride_dt_out_chunk=dt_out.stride(1),
|
||||
stride_dt_out_csize=dt_out.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
DT_SOFTPLUS=dt_softplus,
|
||||
HAS_DT_BIAS=dt_bias is not None,
|
||||
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
||||
)
|
||||
return dA_cumsum, dt_out
|
||||
|
||||
|
||||
def _chunk_state_fwd(
|
||||
B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True
|
||||
):
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (seqlen, ngroups, dstate)
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
|
||||
if states is not None:
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
else:
|
||||
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
||||
states = torch.empty(
|
||||
(nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype
|
||||
)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
||||
nchunks,
|
||||
nheads,
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_fwd_kernel[grid](
|
||||
x_ptr=x,
|
||||
b_ptr=B,
|
||||
states_ptr=states,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
chunk_size=chunk_size,
|
||||
seqlen=seqlen,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_b_seqlen=B.stride(0),
|
||||
stride_b_head=B.stride(1),
|
||||
stride_b_dstate=B.stride(2),
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
)
|
||||
return states
|
||||
|
||||
|
||||
def chunk_state_varlen(
|
||||
B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
|
||||
):
|
||||
total_seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
batch = cu_seqlens.shape[0] - 1
|
||||
cu_seqlens = cu_seqlens.contiguous()
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (total_seqlen, ngroups, dstate)
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
|
||||
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
|
||||
states = torch.empty(
|
||||
batch,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
dtype=chunk_states.dtype,
|
||||
device=chunk_states.device,
|
||||
)
|
||||
|
||||
initial_states_strides = (
|
||||
(
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3),
|
||||
)
|
||||
if initial_states is not None
|
||||
else (0, 0, 0, 0)
|
||||
)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
||||
batch,
|
||||
nheads,
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_varlen_kernel[grid](
|
||||
x_ptr=x,
|
||||
b_ptr=B,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
chunk_states_ptr=chunk_states,
|
||||
cu_seqlens_ptr=cu_seqlens,
|
||||
states_ptr=states,
|
||||
initstates_ptr=initial_states,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
chunk_size=chunk_size,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_b_seqlen=B.stride(0),
|
||||
stride_b_head=B.stride(1),
|
||||
stride_b_dstate=B.stride(2),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_chunk_states_chunk=chunk_states.stride(0),
|
||||
stride_chunk_states_head=chunk_states.stride(1),
|
||||
stride_chunk_states_hdim=chunk_states.stride(2),
|
||||
stride_chunk_states_dstate=chunk_states.stride(3),
|
||||
stride_states_batch=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_init_states_batch=initial_states_strides[0],
|
||||
stride_init_states_head=initial_states_strides[1],
|
||||
stride_init_states_hdim=initial_states_strides[2],
|
||||
stride_init_states_dstate=initial_states_strides[3],
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return states
|
||||
230
model_executor/layers/mamba/ops/ssd_combined.py
Normal file
230
model_executor/layers/mamba/ops/ssd_combined.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
from .ssd_bmm import _bmm_chunk_fwd
|
||||
from .ssd_chunk_scan import _chunk_scan_fwd
|
||||
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
|
||||
from .ssd_state_passing import _state_passing_fwd
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
def is_int_pow_2(n):
|
||||
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
return_intermediate_states=False,
|
||||
seq_idx=None,
|
||||
cu_seqlens=None,
|
||||
cu_chunk_seqlens=None,
|
||||
last_chunk_indices=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None,
|
||||
):
|
||||
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (seqlen, ngroups, dstate)
|
||||
assert dt.shape == (seqlen, nheads)
|
||||
assert A.shape == (nheads,)
|
||||
assert C.shape == B.shape
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if (
|
||||
x.stride(-1) != 1 and x.stride(0) != 1
|
||||
): # Either M or K dimension should be contiguous
|
||||
x = x.contiguous()
|
||||
if (
|
||||
z is not None and z.stride(-1) != 1 and z.stride(0) != 1
|
||||
): # Either M or K dimension should be contiguous
|
||||
z = z.contiguous()
|
||||
if D is not None and D.stride(-1) != 1:
|
||||
D = D.contiguous()
|
||||
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
|
||||
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)
|
||||
|
||||
# This function executes 5 sub-functions for computing mamba
|
||||
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
|
||||
# which has a minimal implementation to understand the below operations
|
||||
# - as explained by the blog, mamba is a special case of causal attention
|
||||
# - the idea is to chunk the attention matrix and compute each
|
||||
# submatrix separately using different optimizations.
|
||||
# - see the blog and paper for a visualization of the submatrices
|
||||
# which we refer to in the comments below
|
||||
|
||||
# 1. Compute chunked cumsum of A * dt
|
||||
# - here dt may go through a softplus activation
|
||||
dA_cumsum, dt = _chunk_cumsum_fwd(
|
||||
dt,
|
||||
A,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit,
|
||||
)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
states = _chunk_state_fwd(
|
||||
B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
|
||||
)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
# - for handling chunked prefill, this requires i) initial_states and
|
||||
# ii) seq_idx to be all specified.
|
||||
# - When a new seq_idx is detected, we will stop passing the prev_state
|
||||
# and switch accordingly to the init_state corresponding to the new seq_idx.
|
||||
states = _state_passing_fwd(
|
||||
rearrange(states, "... p n -> ... (p n)"),
|
||||
dA_cumsum, # (nheads, nchunks, chunk_size)
|
||||
cu_chunk_seqlens,
|
||||
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
|
||||
if initial_states is not None
|
||||
else None, # (batch, nheads, headdim*dstate)
|
||||
seq_idx=seq_idx,
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
||||
)
|
||||
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
||||
|
||||
# 4. Compute batched matrix multiply for C_j^T B_i terms
|
||||
CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)
|
||||
|
||||
# 5. Scan and compute the diagonal blocks, taking into
|
||||
# account past causal states.
|
||||
# - if initial states are provided, then states information will be
|
||||
# augmented with initial_states.
|
||||
# - to do this properly, we need to account for example changes in
|
||||
# the continuous batch, therefore we introduce pseudo chunks, which is
|
||||
# a chunk that is split up each time an example changes.
|
||||
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
|
||||
# a seq_idx change, in which case we take states information from
|
||||
# init_states.
|
||||
_chunk_scan_fwd(
|
||||
CB,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
cu_chunk_seqlens,
|
||||
out, # in-place update
|
||||
seq_idx,
|
||||
D=D,
|
||||
z=z,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
if return_intermediate_states:
|
||||
return states
|
||||
else:
|
||||
return states[last_chunk_indices]
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined_varlen(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
return_intermediate_states=False,
|
||||
state_dtype=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
x: (seqlen, nheads, headdim)
|
||||
dt: (seqlen, nheads)
|
||||
A: (nheads)
|
||||
B: (seqlen, ngroups, dstate)
|
||||
C: (seqlen, ngroups, dstate)
|
||||
chunk_size: int
|
||||
cu_seqlens: (batch + 1,)
|
||||
cu_chunk_seqlens: (nchunks + 1,)
|
||||
last_chunk_indices: (batch,)
|
||||
seq_idx: (nchunks,)
|
||||
out: (seqlen, nheads, headdim) preallocated output tensor
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,)
|
||||
initial_states: (batch, nheads, headdim, dstate)
|
||||
dt_softplus: Whether to apply softplus to dt
|
||||
out: (seqlen, nheads, headdim) preallocated output tensor
|
||||
state_dtype: The data type of the ssm state
|
||||
Return:
|
||||
varlen_states: (batch, nheads, headdim, dstate)
|
||||
"""
|
||||
|
||||
assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
|
||||
assert seq_idx is not None
|
||||
|
||||
varlen_states = _mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
out,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=return_intermediate_states,
|
||||
seq_idx=seq_idx,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit,
|
||||
state_dtype=state_dtype,
|
||||
)
|
||||
|
||||
return varlen_states
|
||||
157
model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
157
model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_SIZE": 64}),
|
||||
triton.Config({"BLOCK_SIZE": 128}),
|
||||
triton.Config({"BLOCK_SIZE": 256}),
|
||||
triton.Config({"BLOCK_SIZE": 512}),
|
||||
triton.Config({"BLOCK_SIZE": 1024}),
|
||||
triton.Config({"BLOCK_SIZE": 2048}),
|
||||
],
|
||||
key=["dim"],
|
||||
)
|
||||
@triton.jit
|
||||
def _state_passing_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
states_ptr,
|
||||
out_ptr,
|
||||
dA_cs_ptr,
|
||||
initstates_ptr,
|
||||
seq_idx_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
dim: tl.constexpr,
|
||||
nchunks,
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
# Strides
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_dim: tl.constexpr,
|
||||
stride_out_chunk: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_out_dim: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_initstates_batch: tl.int64,
|
||||
stride_initstates_head: tl.int64,
|
||||
stride_initstates_dim: tl.constexpr,
|
||||
stride_seq_idx_chunk: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_h = tl.program_id(axis=1)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
|
||||
states_ptr += pid_h * stride_states_head
|
||||
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
|
||||
out_ptr += pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
states_ptrs = states_ptr + offs_m * stride_states_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptrs = (
|
||||
initstates_ptr
|
||||
+ pid_h * stride_initstates_head
|
||||
+ offs_m * stride_initstates_dim
|
||||
)
|
||||
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
prev_seq_idx = 0
|
||||
for c in range(nchunks):
|
||||
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
|
||||
# we have started a new sequence
|
||||
if prev_seq_idx != seq_idx:
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptrs = (
|
||||
initstates_ptr
|
||||
+ seq_idx * stride_initstates_batch
|
||||
+ pid_h * stride_initstates_head
|
||||
+ offs_m * stride_initstates_dim
|
||||
)
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
prev_seq_idx = seq_idx
|
||||
states = tl.exp(dA_cs) * states + new_states
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
|
||||
states_ptrs += stride_states_chunk
|
||||
dA_cs_ptr += stride_dA_cs_chunk
|
||||
out_ptrs += stride_out_chunk
|
||||
|
||||
|
||||
def _state_passing_fwd(
|
||||
states,
|
||||
dA_cumsum,
|
||||
cu_chunk_seqlens,
|
||||
seq_idx,
|
||||
initial_states=None,
|
||||
out_dtype=None,
|
||||
):
|
||||
nchunks, nheads, dim = states.shape
|
||||
chunk_size = dA_cumsum.shape[-1]
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
seqlen = seq_idx.shape[-1]
|
||||
out_dtype = states.dtype if out_dtype is None else out_dtype
|
||||
out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)
|
||||
|
||||
initial_states_strides = (
|
||||
(initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
|
||||
if initial_states is not None
|
||||
else (0, 0, 0)
|
||||
)
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
|
||||
with torch.cuda.device(states.device.index):
|
||||
_state_passing_fwd_kernel[grid](
|
||||
states_ptr=states,
|
||||
out_ptr=out,
|
||||
dA_cs_ptr=dA_cumsum,
|
||||
initstates_ptr=initial_states,
|
||||
seq_idx_ptr=seq_idx,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
dim=dim,
|
||||
nchunks=nchunks,
|
||||
seqlen=seqlen if seq_idx is not None else 0,
|
||||
chunk_size=chunk_size if seq_idx is not None else 0,
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_dim=states.stride(2),
|
||||
stride_out_chunk=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_out_dim=out.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_initstates_batch=initial_states_strides[0],
|
||||
stride_initstates_head=initial_states_strides[1],
|
||||
stride_initstates_dim=initial_states_strides[2],
|
||||
stride_seq_idx_chunk=seq_idx.stride(0),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return out
|
||||
Reference in New Issue
Block a user