first commit
0
vllm/model_executor/layers/mamba/ops/__init__.py
Normal file
1092
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
File diff suppressed because it is too large
168
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row: tl.int64,
|
||||
stride_y_row: tl.int64,
|
||||
stride_z_row: tl.int64,
|
||||
M: tl.int64, # number of rows in X
|
||||
N: tl.int64, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
||||
device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
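# e.g. 65536 // 4 = 16384 columns for fp32, 32768 for fp16/bf16, so one row of features fits in 64KB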
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
def rms_norm_gated(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, _, _ = _layer_norm_fwd(x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=True)
|
||||
|
||||
return y.reshape(x_shape_og)
|
||||
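A minimal call-site sketch for rms_norm_gated (the tensor sizes, dtype, and CUDA device below are illustrative assumptions, not taken from this commit):
import torch
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated

x = torch.randn(4, 256, device="cuda", dtype=torch.float16)   # (rows, hidden)
z = torch.randn_like(x)                                       # gate branch, same shape as x
weight = torch.ones(256, device="cuda", dtype=torch.float16)
y = rms_norm_gated(x, weight, bias=None, z=z, eps=1e-6,
                   group_size=128, norm_before_gate=True)     # two groups of 128 features
assert y.shape == x.shape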
414
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
@@ -0,0 +1,414 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
|
||||
TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
|
||||
>= version.parse("3.0.0"))
|
||||
|
||||
if TRITON3:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
||||
return dt
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
||||
return dt
|
||||
|
||||
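# Note: softplus(x) = log(1 + exp(x)) = x + log1p(exp(-x)). For x > 20 the
# correction term log1p(exp(-x)) is below ~2e-9, smaller than float32 spacing
# at that magnitude, so both variants above pass large values through unchanged.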
|
||||
@triton.heuristics(
|
||||
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics({
|
||||
"HAS_STATE_BATCH_INDICES":
|
||||
lambda args: args["state_batch_indices_ptr"] is not None
|
||||
})
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||
@triton.jit
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
x_ptr,
|
||||
dt_ptr,
|
||||
dt_bias_ptr,
|
||||
A_ptr,
|
||||
B_ptr,
|
||||
C_ptr,
|
||||
D_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_state_batch,
|
||||
stride_state_head,
|
||||
stride_state_dim,
|
||||
stride_state_dstate,
|
||||
stride_x_batch,
|
||||
stride_x_head,
|
||||
stride_x_dim,
|
||||
stride_dt_batch,
|
||||
stride_dt_head,
|
||||
stride_dt_dim,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_bias_dim,
|
||||
stride_A_head,
|
||||
stride_A_dim,
|
||||
stride_A_dstate,
|
||||
stride_B_batch,
|
||||
stride_B_group,
|
||||
stride_B_dstate,
|
||||
stride_C_batch,
|
||||
stride_C_group,
|
||||
stride_C_dstate,
|
||||
stride_D_head,
|
||||
stride_D_dim,
|
||||
stride_z_batch,
|
||||
stride_z_head,
|
||||
stride_z_dim,
|
||||
stride_out_batch,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
|
||||
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from state_batch_indices_ptr. Otherwise, the state coordinate
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += (state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
else:
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
|
||||
offs_n[None, :] * stride_A_dstate)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0)
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
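# Discretized SSM recurrence: state <- exp(dt * A) * state + (dt * B) * x,
# followed below by y = <state, C>, plus D * x and SiLU(z) gating when present.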
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
|
||||
state = state * dA + dB * x[:, None]
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
tl.store(state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
|
||||
def selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
pad_slot_id: int
if state_batch_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
|
||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||
In-place updated.
|
||||
"""
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
if out.dim() == 2:
|
||||
out = out.unsqueeze(1)
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch, )
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||
(0, 0, 0))
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
|
||||
((16, 4) if dstate <= 32 else
|
||||
((8, 4) if dstate <= 64 else
|
||||
((4, 4) if dstate <= 128 else ((4, 8))))))
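# i.e. (BLOCK_SIZE_M, num_warps) = (32, 4) for dstate <= 16, (16, 4) for dstate <= 32,
# (8, 4) for dstate <= 64, (4, 4) for dstate <= 128, and (4, 8) otherwise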
|
||||
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(
|
||||
-1) == 0 and dt_bias.stride(-1) == 0
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
dt_bias,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
pad_slot_id,
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads // ngroups,
|
||||
state.stride(0),
|
||||
state.stride(1),
|
||||
state.stride(2),
|
||||
state.stride(3),
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
*(dt_bias.stride(0),
|
||||
dt_bias.stride(1)) if dt_bias is not None else (0, 0),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
*(D.stride(0), D.stride(1)) if D is not None else (0, 0),
|
||||
z_strides[0],
|
||||
z_strides[1],
|
||||
z_strides[2],
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
dt_softplus,
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
||||
def selective_scan_fn(u,
|
||||
ssm_states,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
applies changes in place.
|
||||
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
applies changes in place.
|
||||
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
A: (dim, dstate)
|
||||
B: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
C: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
D: (dim,)
|
||||
z: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
delta_bias: (dim,)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into the sequence; prepended with 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
A tensor where each cell points to the corresponding
input and output ssm_state index
|
||||
has_initial_state: (batch) bool
|
||||
A tensor populated with ones and zeros,
|
||||
indicating if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
there's no initial state
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at indices 0 and 3
|
||||
returns
|
||||
output: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
supports inplace replacement
|
||||
"""
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3 and query_start_loc is None:
|
||||
B = B.unsqueeze(1)
|
||||
if B.dim() == 2 and query_start_loc is not None:
|
||||
B = B.unsqueeze(0)
|
||||
if C.dim() == 3 and query_start_loc is None:
|
||||
C = C.unsqueeze(1)
|
||||
if C.dim() == 2 and query_start_loc is not None:
|
||||
C = C.unsqueeze(0)
|
||||
|
||||
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
|
||||
query_start_loc, cache_indices, has_initial_state,
|
||||
ssm_states, pad_slot_id)
|
||||
|
||||
if z is None:
|
||||
return delta # output written inplace to delta
|
||||
else:
|
||||
return z # output written inplace to z
|
||||
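An illustrative single-token decode step through selective_state_update (all sizes, the CUDA device, and the negative-A initialization are assumptions; state and out are updated in place, as the docstring above notes):
import torch
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update

batch, nheads, ngroups, dim, dstate = 2, 8, 1, 64, 16
state = torch.zeros(batch, nheads, dim, dstate, device="cuda")
x = torch.randn(batch, nheads, dim, device="cuda")
dt = torch.rand(batch, nheads, dim, device="cuda")
A = -torch.rand(nheads, dim, dstate, device="cuda")
B = torch.randn(batch, ngroups, dstate, device="cuda")
C = torch.randn(batch, ngroups, dstate, device="cuda")
D = torch.randn(nheads, dim, device="cuda")
dt_bias = torch.rand(nheads, dim, device="cuda")
out = torch.empty_like(x)
selective_state_update(state, x, dt, A, B, C, D=D, dt_bias=dt_bias,
                       dt_softplus=True, out=out)  # returns None; state and out are updated in place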
242
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['chunk_size', 'K', 'IS_CAUSAL'],
|
||||
)
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
ngroups: tl.constexpr,
|
||||
stride_a_seqlen: tl.int64,
|
||||
stride_a_head: tl.int64,
|
||||
stride_ak: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_bk: tl.constexpr,
|
||||
stride_out_chunk: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_outm: tl.int64,
|
||||
stride_outn: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
dot_dtype: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_ch = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_c = pid_ch // ngroups
|
||||
pid_h = pid_ch - pid_c * ngroups
|
||||
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
if IS_CAUSAL:
|
||||
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
||||
return
|
||||
a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
||||
b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
|
||||
|
||||
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
|
||||
offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
|
||||
offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# compute a * b.T
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0).to(dot_dtype)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
|
||||
(offs_n[None, :] < chunk_size_limit),
|
||||
other=0.0).to(dot_dtype)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
# Zero out the results that are not from the same request
|
||||
# in the varlen batch
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1)
|
||||
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
|
||||
mask=offs_n < chunk_size_limit,
|
||||
other=-2)
|
||||
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
|
||||
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
|
||||
offs_n[None, :] * stride_outn)
|
||||
tl.store(out_ptrs,
|
||||
out,
|
||||
mask=(offs_m[:, None] < chunk_size) &
|
||||
(offs_n[None, :] < chunk_size))
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (seqlen, ngroups, k)
|
||||
b: (seqlen, ngroups, k)
|
||||
seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
|
||||
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
||||
guaranteed to be correct.
|
||||
Return:
|
||||
out: (nchunks, ngroups, chunk_size, chunk_size)
|
||||
"""
|
||||
seqlen, ngroups, k = a.shape
|
||||
assert b.shape == a.shape
|
||||
assert seq_idx is not None
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
if a.stride(-1) != 1 and a.stride(0) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(-1) != 1 and b.stride(0) != 1:
|
||||
b = b.contiguous()
|
||||
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
|
||||
device=a.device,
|
||||
dtype=out_dtype)
|
||||
dot_dtype = (tl.bfloat16
|
||||
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
|
||||
(tl.float16 if a.dtype == torch.float16
|
||||
or b.dtype == torch.float16 else tl.float32))
|
||||
grid = lambda META: (triton.cdiv(
|
||||
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups)
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_fwd_kernel[grid](
|
||||
a_ptr=a,
|
||||
b_ptr=b,
|
||||
out_ptr=out,
|
||||
seq_idx_ptr=seq_idx,
|
||||
seqlen=seqlen,
|
||||
chunk_size=chunk_size,
|
||||
K=k,
|
||||
ngroups=ngroups,
|
||||
stride_a_seqlen=a.stride(0),
|
||||
stride_a_head=a.stride(1),
|
||||
stride_ak=a.stride(2),
|
||||
stride_b_seqlen=b.stride(0),
|
||||
stride_b_head=b.stride(1),
|
||||
stride_bk=b.stride(2),
|
||||
stride_out_chunk=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_outm=out.stride(-2),
|
||||
stride_outn=out.stride(-1),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
IS_CAUSAL=causal,
|
||||
dot_dtype=dot_dtype,
|
||||
)
|
||||
return out
|
||||
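A shape-level sketch of calling the private helper _bmm_chunk_fwd directly (the sizes and the two-request seq_idx layout are assumptions chosen only to illustrate the varlen masking):
import torch
from vllm.model_executor.layers.mamba.ops.ssd_bmm import _bmm_chunk_fwd

seqlen, ngroups, k, chunk_size = 300, 2, 64, 256
a = torch.randn(seqlen, ngroups, k, device="cuda", dtype=torch.float16)
b = torch.randn_like(a)
# two requests packed back to back: tokens 0..199 belong to request 0, 200..299 to request 1
seq_idx = torch.cat([torch.zeros(200, dtype=torch.int32, device="cuda"),
                     torch.ones(100, dtype=torch.int32, device="cuda")])
out = _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=True)
# out: (nchunks=2, ngroups=2, chunk_size=256, chunk_size=256); cross-request entries are zeroed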
527
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
Normal file
@@ -0,0 +1,527 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
cb_ptr,
|
||||
x_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
C_ptr,
|
||||
states_ptr,
|
||||
D_ptr,
|
||||
initstates_ptr,
|
||||
chunk_indices_ptr,
|
||||
chunk_offsets_ptr,
|
||||
chunk_meta_num,
|
||||
# Matrix dimensions
|
||||
chunk_size: tl.constexpr,
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_cb_chunk: tl.int64,
|
||||
stride_cb_head: tl.int64,
|
||||
stride_cb_csize_m: tl.int64,
|
||||
stride_cb_csize_k: tl.constexpr,
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_z_seqlen: tl.int64,
|
||||
stride_z_head: tl.int64,
|
||||
stride_z_hdim: tl.constexpr,
|
||||
stride_out_seqlen: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_out_hdim: tl.constexpr,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
stride_C_seqlen: tl.int64,
|
||||
stride_C_head: tl.int64,
|
||||
stride_C_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_init_states_batch: tl.int64,
|
||||
stride_init_states_head: tl.int64,
|
||||
stride_init_states_hdim: tl.int64,
|
||||
stride_init_states_dstate: tl.constexpr,
|
||||
stride_D_head: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
D_HAS_HDIM: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
IS_TRITON_22: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
if not HAS_INITSTATES:
|
||||
c_idx = pid_c
|
||||
c_off = 0
|
||||
else:
|
||||
c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
cb_ptr += c_idx * stride_cb_chunk + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_cb_head
|
||||
x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
C_ptr += c_idx * chunk_size * stride_C_seqlen + (
|
||||
pid_h // nheads_ngroups_ratio) * stride_C_head
|
||||
|
||||
# M-block offsets and prev states
|
||||
# - logic in next block may override these if there is an active offset
|
||||
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head
|
||||
prev_states_hdim = stride_states_hdim
|
||||
prev_states_dstate = stride_states_dstate
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
|
||||
|
||||
seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen
|
||||
# - we only need seq_idx_prev to be aligned to chunk boundary
|
||||
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
|
||||
mask=c_idx >= 1,
|
||||
other=0)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states, we only need seq_idx_m to point
# to the current seq_idx
|
||||
|
||||
# get current seq idx
|
||||
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
|
||||
seq_idx_m = tl.load(
|
||||
seq_idx_ptr +
|
||||
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
|
||||
|
||||
# - recall that in ssd_state_passing, for the case c_off == 0
|
||||
# i.e., the very first sequence, we made states_ptr hold its initial state
|
||||
# so this edge case is taken care of
|
||||
if ((c_off == 0) and (seq_idx_prev != seq_idx_m
|
||||
) # if a seq is changed exactly on boundary
|
||||
or (c_off > 0) # implies a new example (pseudo chunk)
|
||||
):
|
||||
|
||||
# - replace prev_states_ptr with init_states
|
||||
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
|
||||
prev_states_hdim = stride_init_states_hdim # override strides
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
|
||||
mask=offs_m < chunk_size,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# - handle chunk state limit
|
||||
if HAS_INITSTATES:
|
||||
# have to split this if, otherwise compilation will have problems
|
||||
dA_cs_m_boundary = 0.0
|
||||
|
||||
# get the c_idx for the next (logical) chunk
|
||||
c_idx_n = tl.load(
|
||||
chunk_indices_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=-1 # to trigger different chunk
|
||||
)
|
||||
|
||||
# - there are two things to consider
|
||||
# A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
|
||||
# contribution of past states
|
||||
# B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
|
||||
# encroach into the next sequence, where c_off_n is the offset of the next
|
||||
# (logical) chunk.
|
||||
# An equivalent check for B is c_idx == c_idx_n, where there is repetition in
|
||||
# (logical) chunk indices.
|
||||
|
||||
if (c_idx == c_idx_n) or c_off > 0:
|
||||
|
||||
# get the next offset
|
||||
c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=chunk_size)
|
||||
|
||||
# in this case, adjust down the chunk_size_limit
|
||||
if c_idx == c_idx_n:
|
||||
chunk_size_limit = min(c_off_n, chunk_size_limit)
|
||||
|
||||
# get the cs at the offset boundary
|
||||
# - c_off == 0 is a passthrough
|
||||
# - We need dA_cs at the boundary, defined by c_off - no need
|
||||
# to increase pointer by pid_m (it is a constant offset,
|
||||
# i.e. the same for all blocks)
|
||||
dA_cs_m_boundary = tl.load(
|
||||
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
|
||||
other=0.0).to(tl.float32)
|
||||
else:
|
||||
# - handle seq idx when HAS_INITSTATES==False
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# Without the if (pid_c > -1), with Triton 2.1.0, I get
|
||||
# Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed.
|
||||
# With Triton 2.2.0, this works
|
||||
if IS_TRITON_22 or c_idx > -1:
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k_dstate = tl.arange(
|
||||
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
||||
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
|
||||
offs_k_dstate[None, :] * stride_C_dstate)
|
||||
|
||||
prev_states_ptrs = prev_states_ptr + (
|
||||
offs_n[None, :] * prev_states_hdim +
|
||||
offs_k_dstate[:, None] * prev_states_dstate)
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
# - this is for continuous batching where there is no init states
|
||||
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
|
||||
else:
|
||||
# - if there is initstates, we will rely on prev_states, no zeroing
|
||||
# required.
|
||||
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
|
||||
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate),
|
||||
other=0.0)
|
||||
|
||||
prev_states = tl.load(prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc = tl.dot(C, prev_states) * scale_m[:, None]
|
||||
else:
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate - k),
|
||||
other=0.0)
|
||||
# C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate - k) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc += tl.dot(C, prev_states)
|
||||
C_ptrs += BLOCK_SIZE_K
|
||||
prev_states_ptrs += BLOCK_SIZE_K
|
||||
acc *= scale_m[:, None]
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
|
||||
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
|
||||
offs_k[None, :] * stride_cb_csize_k)
|
||||
x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
|
||||
offs_n[None, :] * stride_x_hdim)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
K_MAX = chunk_size_limit if not IS_CAUSAL else min(
|
||||
(pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
|
||||
for k in range(0, K_MAX, BLOCK_SIZE_K):
|
||||
cb = tl.load(cb_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size) &
|
||||
(offs_k[None, :] < chunk_size - k),
|
||||
other=0.0).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs,
|
||||
mask=offs_k < chunk_size - k,
|
||||
other=0.0).to(tl.float32)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
|
||||
other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
if IS_CAUSAL:
|
||||
mask = offs_m[:, None] >= k + offs_k[None, :]
|
||||
cb = tl.where(mask, cb, 0.0)
|
||||
cb = cb.to(x_ptr.dtype.element_ty)
|
||||
x = tl.load(x_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
acc += tl.dot(cb, x)
|
||||
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
if HAS_D:
|
||||
if D_HAS_HDIM:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n,
|
||||
mask=offs_n < hdim,
|
||||
other=0.0).to(tl.float32)
|
||||
else:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
||||
x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen +
|
||||
offs_n[None, :] * stride_x_hdim),
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0).to(tl.float32)
|
||||
acc += x_residual * D
|
||||
|
||||
if HAS_Z:
|
||||
z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
|
||||
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
|
||||
stride_z_hdim * offs_out_n[None, :])
|
||||
z = tl.load(z_ptrs,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit) &
|
||||
(offs_out_n[None, :] < hdim),
|
||||
other=0.0).to(tl.float32)
|
||||
acc *= z * tl.sigmoid(z)
|
||||
|
||||
out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
|
||||
offs_out_n[None, :] * stride_out_hdim)
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit) &
|
||||
(offs_out_n[None, :] < hdim))
|
||||
|
||||
|
||||
def _chunk_scan_fwd(
|
||||
cb,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
out,
|
||||
seq_idx,
|
||||
D=None,
|
||||
z=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
initial_states=None,
|
||||
):
|
||||
assert seq_idx is not None, "this implementation requires seq_idx"
|
||||
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = C.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert C.shape == (seqlen, ngroups, dstate)
|
||||
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
|
||||
if initial_states is not None:
|
||||
# with initial states, we need to take care of how
|
||||
# seq_idx crosses the boundaries
|
||||
assert chunk_indices is not None and chunk_offsets is not None, \
|
||||
"chunk_indices and chunk_offsets should have been set"
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
headdim, META['BLOCK_SIZE_N']), nchunks
|
||||
if chunk_offsets is None else len(chunk_offsets), nheads)
|
||||
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||
(0, 0, 0))
|
||||
initial_states_strides = ((initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3))
|
||||
if initial_states is not None else (0, 0, 0, 0))
|
||||
|
||||
_chunk_scan_fwd_kernel[grid](
|
||||
cb_ptr=cb,
|
||||
x_ptr=x,
|
||||
z_ptr=z,
|
||||
out_ptr=out,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
seq_idx_ptr=seq_idx,
|
||||
C_ptr=C,
|
||||
states_ptr=states,
|
||||
D_ptr=D,
|
||||
initstates_ptr=initial_states,
|
||||
chunk_indices_ptr=chunk_indices,
|
||||
chunk_offsets_ptr=chunk_offsets,
|
||||
chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0,
|
||||
chunk_size=chunk_size,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
seqlen=seqlen,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_cb_chunk=cb.stride(0),
|
||||
stride_cb_head=cb.stride(1),
|
||||
stride_cb_csize_m=cb.stride(2),
|
||||
stride_cb_csize_k=cb.stride(3),
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_z_seqlen=z_strides[0],
|
||||
stride_z_head=z_strides[1],
|
||||
stride_z_hdim=z_strides[2],
|
||||
stride_out_seqlen=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_out_hdim=out.stride(2),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
stride_C_seqlen=C.stride(0),
|
||||
stride_C_head=C.stride(1),
|
||||
stride_C_dstate=C.stride(2),
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_init_states_batch=initial_states_strides[0],
|
||||
stride_init_states_head=initial_states_strides[1],
|
||||
stride_init_states_hdim=initial_states_strides[2],
|
||||
stride_init_states_dstate=initial_states_strides[3],
|
||||
stride_D_head=D.stride(0) if D is not None else 0,
|
||||
IS_CAUSAL=True,
|
||||
HAS_D=D is not None,
|
||||
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
|
||||
HAS_Z=z is not None,
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
IS_TRITON_22=TRITON_22,
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return
|
||||
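As a worked example of the launch grid computed above, assuming no initial states (so the second grid axis is nchunks) and one candidate autotune config; the numbers are illustrative:
import triton

chunk_size, headdim, nchunks, nheads = 256, 64, 4, 8
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 64   # one config from the autotuner above
grid = (triton.cdiv(chunk_size, BLOCK_SIZE_M) * triton.cdiv(headdim, BLOCK_SIZE_N),
        nchunks, nheads)
# -> (4, 4, 8): one program per (M-tile, N-tile) pair, per chunk, per head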
724
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
Normal file
@@ -0,0 +1,724 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .mamba_ssm import softplus
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_H': 2}),
|
||||
triton.Config({'BLOCK_SIZE_H': 4}),
|
||||
triton.Config({'BLOCK_SIZE_H': 8}),
|
||||
triton.Config({'BLOCK_SIZE_H': 16}),
|
||||
triton.Config({'BLOCK_SIZE_H': 32}),
|
||||
triton.Config({'BLOCK_SIZE_H': 64}),
|
||||
],
|
||||
key=['chunk_size', 'nheads'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
dt_ptr,
|
||||
A_ptr,
|
||||
dt_bias_ptr,
|
||||
dt_out_ptr,
|
||||
dA_cumsum_ptr,
|
||||
# Matrix dimension
|
||||
seqlen,
|
||||
nheads: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
dt_min: tl.constexpr,
|
||||
dt_max: tl.constexpr,
|
||||
# Strides
|
||||
stride_dt_seqlen: tl.int64,
|
||||
stride_dt_head: tl.constexpr,
|
||||
stride_A_head: tl.constexpr,
|
||||
stride_dt_bias_head: tl.constexpr,
|
||||
stride_dt_out_head: tl.int64,
|
||||
stride_dt_out_chunk: tl.int64,
|
||||
stride_dt_out_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
BLOCK_SIZE_H: tl.constexpr,
|
||||
BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||
):
|
||||
# if dt is long, 32-bit offsets may overflow, so use 64 bit
|
||||
# https://github.com/triton-lang/triton/issues/1058
|
||||
pid_c = tl.program_id(axis=0).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=1)
|
||||
dt_ptr += pid_c * chunk_size * stride_dt_seqlen
|
||||
dt_out_ptr += pid_c * stride_dt_out_chunk
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
|
||||
|
||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head +
|
||||
offs_c[None, :] * stride_dt_seqlen)
|
||||
A_ptrs = A_ptr + offs_h * stride_A_head
|
||||
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head +
|
||||
offs_c[None, :] * stride_dt_out_csize)
|
||||
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
|
||||
offs_c[None, :] * stride_dA_cs_csize)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
dt = tl.load(dt_ptrs,
|
||||
mask=(offs_h[:, None] < nheads) &
|
||||
(offs_c[None, :] < chunk_size_limit),
|
||||
other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head,
|
||||
mask=offs_h < nheads,
|
||||
other=0.0).to(tl.float32)
|
||||
dt += dt_bias[:, None]
|
||||
if DT_SOFTPLUS:
|
||||
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
||||
|
||||
dt = tl.clamp(dt, dt_min, dt_max)
|
||||
dt = tl.where(
|
||||
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
|
||||
0.0)
|
||||
tl.store(dt_out_ptrs,
|
||||
dt,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
||||
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dA = dt * A[:, None]
|
||||
dA_cs = tl.cumsum(dA, axis=1)
|
||||
tl.store(dA_cs_ptrs,
|
||||
dA_cs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['hdim', 'dstate', 'chunk_size'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
states_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_b_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
|
||||
pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
|
||||
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
|
||||
offs_k[None, :] * stride_x_seqlen)
|
||||
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
|
||||
offs_k[:, None] * stride_b_seqlen)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr +
|
||||
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
seq_idx_last = tl.load(seq_idx_ptr +
|
||||
(chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) &
|
||||
(offs_k[None, :] < chunk_size_limit - k),
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) &
|
||||
(offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs,
|
||||
mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
seq_idx_k = tl.load(seq_idx_ptrs,
|
||||
mask=offs_k < chunk_size_limit - k,
|
||||
other=-1)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
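# Each timestep contributes x_k (outer) B_k weighted by dt_k * exp(dA_cs_last - dA_cs_k);
# positions whose seq_idx differs from the chunk's last token are zeroed out.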
scale = tl.where(seq_idx_k == seq_idx_last,
|
||||
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
|
||||
offs_n[None, :] * stride_states_dstate)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 256,
                'BLOCK_SIZE_K': 64
            },
            num_stages=3,
            num_warps=8),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 256,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 32
            },
            num_stages=5,
            num_warps=2),
        triton.Config(
            {
                'BLOCK_SIZE_M': 32,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32
            },
            num_stages=5,
            num_warps=2),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=2),
    ],
    key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    states_ptr,
    initstates_ptr,
    # Matrix dimensions
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    chunk_size: tl.constexpr,
    nheads_ngroups_ratio: tl.constexpr,
    # Strides
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_b_dstate: tl.constexpr,
    stride_dt_head: tl.int64,
    stride_dt_chunk: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_chunk_states_chunk: tl.int64,
    stride_chunk_states_head: tl.int64,
    stride_chunk_states_hdim: tl.int64,
    stride_chunk_states_dstate: tl.constexpr,
    stride_states_batch: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_init_states_batch: tl.int64,
    stride_init_states_head: tl.int64,
    stride_init_states_hdim: tl.int64,
    stride_init_states_dstate: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size
    b_ptr += pid_c * chunk_size * stride_b_seqlen + (
        pid_h // nheads_ngroups_ratio) * stride_b_head
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head

    if HAS_INITSTATES:
        # if there are init states provided, we differentiate between states (which
        # are boundary conditions at a chunk boundary) and initstates (which are boundary
        # conditions when a new example in a cont batch starts)
        initstates_ptr += pid_h * stride_init_states_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
                      offs_k[None, :] * stride_x_seqlen)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
                      offs_k[:, None] * stride_b_seqlen)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) *
                         stride_dA_cs_csize).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    chunk_size_limit = end_idx - pid_c * chunk_size
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(x_ptrs,
                    mask=(offs_m[:, None] < hdim) &
                    (offs_k[None, :] < chunk_size_limit - k) &
                    (offs_k[None, :] >= start_idx_cur - k),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=(offs_k[:, None] < chunk_size_limit - k) &
                    (offs_n[None, :] < dstate) &
                    (offs_k[:, None] >= start_idx_cur - k),
                    other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs,
                          mask=offs_k < chunk_size_limit - k,
                          other=0.0).to(tl.float32)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
                       other=0.0).to(tl.float32)
        scale = tl.where(
            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
            tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
    # If HAS_INITSTATES==True we need to consider two possibilities
    # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
    # - if start_idx >= pid_c * chunk_size, then we need to insert initstates
    if ((start_idx < pid_c * chunk_size)  # sequence started in an earlier chunk
            or (HAS_INITSTATES)):

        dA_cs_boundary = 0.0  # default

        if not HAS_INITSTATES:
            past_states_ptrs = chunk_states_ptr + (
                offs_m[:, None] * stride_chunk_states_hdim +
                offs_n[None, :] * stride_chunk_states_dstate)
        else:

            # - this seems repetitive, but it's to help the compiler
            if start_idx < pid_c * chunk_size:
                past_states_ptrs = chunk_states_ptr + (
                    offs_m[:, None] * stride_chunk_states_hdim +
                    offs_n[None, :] * stride_chunk_states_dstate)
            else:
                past_states_ptrs = initstates_ptr + (
                    pid_b * stride_init_states_batch +
                    offs_m[:, None] * stride_init_states_hdim +
                    offs_n[None, :] * stride_init_states_dstate)

                # need to adjust the boundary
                if start_idx > pid_c * chunk_size:
                    dA_cs_boundary = tl.load(
                        dA_cumsum_ptr +
                        (start_idx - pid_c * chunk_size - 1) *
                        stride_dA_cs_csize).to(tl.float32)

        past_states = tl.load(past_states_ptrs,
                              mask=(offs_m[:, None] < hdim) &
                              (offs_n[None, :] < dstate),
                              other=0.0).to(tl.float32)

        scale = tl.exp(dA_cs_last - dA_cs_boundary)
        acc += past_states * scale

    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
                                offs_n[None, :] * stride_states_dstate)
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)


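# A worked example of the indexing above (illustrative numbers only, not part
# of the original kernel): with chunk_size = 8 and cu_seqlens = [0, 5, 13],
# the second sequence (pid_b = 1) has end_idx = 13, so its last chunk is
# pid_c = (13 - 1) // 8 = 1 with chunk_size_limit = 13 - 8 = 5 and
# start_idx_cur = max(5 - 8, 0) = 0, i.e. every valid position of chunk 1
# belongs to it. The first sequence (pid_b = 0) has end_idx = 5, pid_c = 0,
# chunk_size_limit = 5 and start_idx_cur = 0, so positions 5..7 of chunk 0
# (which belong to the second sequence) are masked out of its state.

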
def _chunk_cumsum_fwd(dt,
                      A,
                      chunk_size,
                      dt_bias=None,
                      dt_softplus=False,
                      dt_limit=(0.0, float("inf"))):
    seqlen, nheads = dt.shape
    assert A.shape == (nheads, )
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, )
    nchunks = math.ceil(seqlen / chunk_size)
    dt_out = torch.empty(nheads,
                         nchunks,
                         chunk_size,
                         device=dt.device,
                         dtype=torch.float32)
    dA_cumsum = torch.empty(nheads,
                            nchunks,
                            chunk_size,
                            device=dt.device,
                            dtype=torch.float32)
    grid_chunk_cs = lambda META: (nchunks,
                                  triton.cdiv(nheads, META['BLOCK_SIZE_H']))
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[grid_chunk_cs](
            dt_ptr=dt,
            A_ptr=A,
            dt_bias_ptr=dt_bias,
            dt_out_ptr=dt_out,
            dA_cumsum_ptr=dA_cumsum,
            seqlen=seqlen,
            nheads=nheads,
            chunk_size=chunk_size,
            dt_min=dt_limit[0],
            dt_max=dt_limit[1],
            stride_dt_seqlen=dt.stride(0),
            stride_dt_head=dt.stride(1),
            stride_A_head=A.stride(0),
            stride_dt_bias_head=dt_bias.stride(0)
            if dt_bias is not None else 0,
            stride_dt_out_head=dt_out.stride(0),
            stride_dt_out_chunk=dt_out.stride(1),
            stride_dt_out_csize=dt_out.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            DT_SOFTPLUS=dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out


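# Unoptimized PyTorch reference of the launch above, added as a reading aid
# (not part of the original file; the handling of the padded tail of the last
# chunk is simplified here - padded positions are simply zero-filled).
def _chunk_cumsum_ref(dt,
                      A,
                      chunk_size,
                      dt_bias=None,
                      dt_softplus=False,
                      dt_limit=(0.0, float("inf"))):
    seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = torch.nn.functional.softplus(dt)
    dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    # pad to a whole number of chunks and reshape to the kernel's
    # (nheads, nchunks, chunk_size) layout
    dt = torch.nn.functional.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    dt = dt.reshape(nchunks, chunk_size, nheads).permute(2, 0, 1)
    # dA_cumsum[h, c, i] is the running sum of A[h] * dt over chunk c up to i
    dA_cumsum = torch.cumsum(dt * A[:, None, None], dim=-1)
    return dA_cumsum, dt

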
def _chunk_state_fwd(B,
                     x,
                     dt,
                     dA_cumsum,
                     seq_idx=None,
                     states=None,
                     states_in_fp32=True):
    seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape

    assert seq_idx is not None
    assert seq_idx.shape == (seqlen, )

    if states is not None:
        assert states.shape == (nchunks, nheads, headdim, dstate)
    else:
        states_dtype = torch.float32 if states_in_fp32 else B.dtype
        states = torch.empty((nchunks, nheads, headdim, dstate),
                             device=x.device,
                             dtype=states_dtype)

    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
                         cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[grid](
            x_ptr=x,
            b_ptr=B,
            states_ptr=states,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            seq_idx_ptr=seq_idx,
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            seqlen=seqlen,
            nheads_ngroups_ratio=nheads // ngroups,
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_states_chunk=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_seq_idx_seqlen=seq_idx.stride(0),
        )
    return states


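# Reading aid (added; not part of the original file). Ignoring the seq_idx
# masking, each per-(chunk, head) state produced above is the outer-product sum
#
#   decay = torch.exp(dA_cumsum[:, :, -1:] - dA_cumsum)   # (nheads, nchunks, chunk_size)
#   B_heads = B.repeat_interleave(nheads // ngroups, dim=1)
#   states[c, h] = sum over positions l of chunk c of
#       (dt[h, c, l] * decay[h, c, l]) * x[l, h, :, None] * B_heads[l, h, None, :]
#
# Positions whose seq_idx differs from that of the chunk's last token are
# zeroed by the kernel, so a chunk state never mixes two different sequences.

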
def chunk_state_varlen(B,
                       x,
                       dt,
                       dA_cumsum,
                       cu_seqlens,
                       chunk_states,
                       initial_states=None):
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)

    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)

    states = torch.empty(batch,
                         nheads,
                         headdim,
                         dstate,
                         dtype=chunk_states.dtype,
                         device=chunk_states.device)

    initial_states_strides = ((initial_states.stride(0),
                               initial_states.stride(1),
                               initial_states.stride(2),
                               initial_states.stride(3))
                              if initial_states is not None else (0, 0, 0, 0))

    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
                         cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x_ptr=x,
            b_ptr=B,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            chunk_states_ptr=chunk_states,
            cu_seqlens_ptr=cu_seqlens,
            states_ptr=states,
            initstates_ptr=initial_states,
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            nheads_ngroups_ratio=nheads // ngroups,
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_chunk_states_chunk=chunk_states.stride(0),
            stride_chunk_states_head=chunk_states.stride(1),
            stride_chunk_states_hdim=chunk_states.stride(2),
            stride_chunk_states_dstate=chunk_states.stride(3),
            stride_states_batch=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_init_states_batch=initial_states_strides[0],
            stride_init_states_head=initial_states_strides[1],
            stride_init_states_hdim=initial_states_strides[2],
            stride_init_states_dstate=initial_states_strides[3],
            HAS_INITSTATES=initial_states is not None)
    return states
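
# Reading aid (added comment, not part of the original file): for every
# sequence in the packed batch, the kernel combines (1) the contribution of the
# sequence's own tokens that fall inside its last chunk with (2) the state
# carried into that chunk (chunk_states, or initial_states when the sequence
# starts inside that chunk), decayed by exp(dA_cs_last - dA_cs_boundary).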
238
vllm/model_executor/layers/mamba/ops/ssd_combined.py
Normal file
238
vllm/model_executor/layers/mamba/ops/ssd_combined.py
Normal file
@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py

# ruff: noqa: E501

import torch
from einops import rearrange
from packaging import version

from vllm.triton_utils import triton

from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
                              chunk_state_varlen)
from .ssd_state_passing import _state_passing_fwd

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')


def is_int_pow_2(n):
    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0


def _mamba_chunk_scan_combined_fwd(x,
                                   dt,
                                   A,
                                   B,
                                   C,
                                   chunk_size,
                                   out,
                                   D=None,
                                   z=None,
                                   dt_bias=None,
                                   initial_states=None,
                                   seq_idx=None,
                                   chunk_indices=None,
                                   chunk_offsets=None,
                                   cu_seqlens=None,
                                   dt_softplus=False,
                                   dt_limit=(0.0, float("inf")),
                                   state_dtype=None):
    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
    seqlen, nheads, headdim = x.shape
    _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate)
    assert dt.shape == (seqlen, nheads)
    assert A.shape == (nheads, )
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads, )
    if seq_idx is not None:
        assert seq_idx.shape == (seqlen, )
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if x.stride(-1) != 1 and x.stride(
            0) != 1:  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if z is not None and z.stride(-1) != 1 and z.stride(
            0) != 1:  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"

    if initial_states is not None:
        assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim,
                                        dstate)

    # This function executes 5 sub-functions for computing mamba
    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
    #   which has a minimal implementation to understand the below operations
    # - as explained by the blog, mamba is a special case of causal attention
    # - the idea is to chunk the attention matrix and compute each
    #   submatrix separately using different optimizations.
    # - see the blog and paper for a visualization of the submatrices
    #   which we refer to in the comments below

    # 1. Compute chunked cumsum of A * dt
    # - here dt may go through a softplus activation
    dA_cumsum, dt = _chunk_cumsum_fwd(dt,
                                      A,
                                      chunk_size,
                                      dt_bias=dt_bias,
                                      dt_softplus=dt_softplus,
                                      dt_limit=dt_limit)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    states = _chunk_state_fwd(B,
                              x,
                              dt,
                              dA_cumsum,
                              seq_idx=seq_idx,
                              states_in_fp32=True)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    # - for handling chunked prefill, this requires i) initial_states,
    #   ii) seq_idx, iii) is_cont_batched and iv) chunk_offsets to be all specified.
    # - When a new seq_idx is detected, we will stop passing the prev_state
    #   and switch accordingly to the init_state corresponding to the new seq_idx.
    # - We will also make sure that the dA_cumsum is taken only from the start of the
    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
    # - this will ensure that states will be updated with the rightmost flushed seq_idx
    #   of the previous chunk. This implies that the first chunk of states is either 0
    #   or equal to init_states of the first example.
    states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum,  # (nheads, nchunks, chunk_size)
        initial_states=rearrange(initial_states, "... p n -> ... (p n)")
        if initial_states is not None else
        None,  # (batch, nheads, headdim*dstate)
        seq_idx=seq_idx,
        out_dtype=state_dtype if state_dtype is not None else C.dtype,
        chunk_offsets=chunk_offsets)
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)

    # 4. Compute batched matrix multiply for C_j^T B_i terms
    CB = _bmm_chunk_fwd(C,
                        B,
                        chunk_size,
                        seq_idx=seq_idx,
                        output_dtype=torch.float32)

    # 5. Scan and compute the diagonal blocks, taking into
    #    account past causal states.
    # - if initial states are provided, then states information will be
    #   augmented with initial_states.
    # - to do this properly, we need to account for example changes in
    #   the continuous batch, therefore we introduce pseudo chunks, which is
    #   a chunk that is split up each time an example changes.
    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
    #   a seq_idx change, in which case we take states information from
    #   init_states.
    _chunk_scan_fwd(
        CB,
        x,
        dt,
        dA_cumsum,
        C,
        states,
        out,  # in-place update
        seq_idx,
        D=D,
        z=z,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        initial_states=initial_states,
    )

    varlen_states = chunk_state_varlen(
        B,
        x,
        dt,
        dA_cumsum,
        cu_seqlens,
        states,
        initial_states=initial_states,
    )

    return varlen_states


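# For orientation (added comment, not part of the original file): per head h
# and its group g = h // (nheads // ngroups), the five fused steps above are
# equivalent, up to numerics, to the plain per-token SSM recurrence over a
# (headdim, dstate) state S:
#
#   S_t = exp(A[h] * dt[t, h]) * S_{t-1} + dt[t, h] * outer(x[t, h], B[t, g])
#   y[t, h] = S_t @ C[t, g]  (+ D * x[t, h] if D is given; gated with silu(z) if z is given)
#
# restarted from initial_states (or zeros) at every boundary given by
# cu_seqlens / seq_idx; `out` receives y and the returned varlen_states hold
# each sequence's final S.

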
def mamba_chunk_scan_combined_varlen(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    cu_seqlens,
    seq_idx,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    chunk_indices=None,
    chunk_offsets=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    state_dtype=None,
):
    """
    Argument:
        x: (seqlen, nheads, headdim)
        dt: (seqlen, nheads)
        A: (nheads)
        B: (seqlen, ngroups, dstate)
        C: (seqlen, ngroups, dstate)
        chunk_size: int
        seq_idx: (seqlen)
        cu_seqlens: (batch + 1)
        out: (seqlen, nheads, headdim) preallocated output tensor
        D: (nheads, headdim) or (nheads,)
        z: (seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        dt_softplus: Whether to apply softplus to dt
        state_dtype: The data type of the ssm state
    Return:
        varlen_states: (batch, nheads, headdim, dstate)
    """

    assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
    assert seq_idx is not None

    varlen_states = _mamba_chunk_scan_combined_fwd(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        out,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        seq_idx=seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        state_dtype=state_dtype)

    return varlen_states
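
# Illustrative call (comment only, added as a reading aid; all shapes, dtypes
# and values below are made up and not taken from the original file):
#
#   nheads, headdim, ngroups, dstate = 8, 64, 1, 128
#   seqlens = [5, 11]                      # two prompts packed into one batch
#   cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
#   seq_idx = torch.repeat_interleave(
#       torch.arange(len(seqlens), dtype=torch.int32, device="cuda"),
#       torch.tensor(seqlens, device="cuda"))            # [0]*5 + [1]*11
#   x = torch.randn(16, nheads, headdim, device="cuda", dtype=torch.bfloat16)
#   dt = torch.rand(16, nheads, device="cuda")
#   A = -torch.rand(nheads, device="cuda")
#   B = torch.randn(16, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
#   C = torch.randn_like(B)
#   out = torch.empty_like(x)
#   final_states = mamba_chunk_scan_combined_varlen(
#       x, dt, A, B, C, chunk_size=256, cu_seqlens=cu_seqlens,
#       seq_idx=seq_idx, out=out, dt_softplus=True)
#   # `out` now holds the scan output for all 16 tokens and `final_states`
#   # has shape (2, nheads, headdim, dstate).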
200
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
200
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py

# ruff: noqa: E501

import torch

from vllm.triton_utils import tl, triton


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr,
    out_ptr,
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    chunk_offsets_ptr,
    chunk_meta_num,
    # Matrix dimensions
    dim: tl.constexpr,
    nchunks,
    seqlen,
    chunk_size: tl.constexpr,
    # Strides
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_dim: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_out_dim: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_initstates_batch: tl.int64,
    stride_initstates_head: tl.int64,
    stride_initstates_dim: tl.constexpr,
    stride_seq_idx_seqlen: tl.constexpr,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid_h = tl.program_id(axis=1)
    pid_m = tl.program_id(axis=0)
    states_ptr += pid_h * stride_states_head
    dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
                                              1) * stride_dA_cs_csize
    out_ptr += pid_h * stride_out_head
    if HAS_INITSTATES:
        initstates_ptr += pid_h * stride_initstates_head

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    # - states will be the past state of the sequence that continues on the current chunk
    if not HAS_INITSTATES:
        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    else:
        initstates_ptr += offs_m * stride_initstates_dim
        initstates_ptrs = initstates_ptr
        # - for cont batches, the state of the first chunk is the first
        #   sequence's init state
        states = tl.load(initstates_ptrs, mask=offs_m < dim,
                         other=0.0).to(tl.float32)

    tl.store(out_ptrs, states, mask=offs_m < dim)
    out_ptrs += stride_out_chunk
    prev_seq_idx_chunk_end = 0
    logical_chunk_idx = 0
    for c in range(nchunks - 1):
        new_states = tl.load(states_ptrs, mask=offs_m < dim,
                             other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale_mask = True
        # - the seq to pass forward is the one that is flushed to the right
        #   boundary.
        # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
        seq_idx_chunk_end = tl.load(seq_idx_ptr +
                                    (min((c + 1) * chunk_size, seqlen) - 1) *
                                    stride_seq_idx_seqlen)

        if HAS_INITSTATES:
            if prev_seq_idx_chunk_end != seq_idx_chunk_end:
                # this means in the current chunk the rightmost flushed seq
                # has changed.
                # - so we do not propagate the state from previous chunk
                # - but rather we load that sequence's init state
                initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch

                # - update state with seq_idx_new's init state
                states = tl.load(initstates_ptrs, mask=offs_m < dim,
                                 other=0.0).to(tl.float32)

                # - we need to consider the cumsum only of the last sequence in the chunk
                # - find its starting position (given by c_off of the logical chunk index)
                # - and subtract the cumsum just before that position from the total cumsum
                # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
                #   sequence index at the start of the current chunk
                seq_idx_chunk_start = tl.load(seq_idx_ptr +
                                              min(c * chunk_size, seqlen) *
                                              stride_seq_idx_seqlen)
                logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
                # - load the chunk offset:
                c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
                                mask=logical_chunk_idx < chunk_meta_num,
                                other=0)
                # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
                if c_off > 0:
                    # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
                    dA_cs_boundary = tl.load(
                        dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
                        (c_off - 1) * stride_dA_cs_csize,
                        mask=(c_off - 1) > -1 and c_off < chunk_size,
                        other=0.0)
                    dA_cs -= dA_cs_boundary

            # - increment logical chunk index for every physical chunk
            logical_chunk_idx += 1
        else:
            scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
        prev_seq_idx_chunk_end = seq_idx_chunk_end

        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
        states = scale * states + new_states
        tl.store(out_ptrs, states, mask=offs_m < dim)

        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk


def _state_passing_fwd(
    states,
    dA_cumsum,
    seq_idx,
    chunk_offsets,
    initial_states=None,
    out_dtype=None,
):
    nchunks, nheads, dim = states.shape
    chunk_size = dA_cumsum.shape[-1]
    assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
    seqlen = seq_idx.shape[-1]
    out_dtype = states.dtype if out_dtype is None else out_dtype
    out = torch.empty((nchunks, nheads, dim),
                      device=states.device,
                      dtype=out_dtype)

    initial_states_strides = ((initial_states.stride(0),
                               initial_states.stride(1),
                               initial_states.stride(2))
                              if initial_states is not None else (0, 0, 0))

    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads)
    with torch.cuda.device(states.device.index):
        _state_passing_fwd_kernel[grid](
            states_ptr=states,
            out_ptr=out,
            dA_cs_ptr=dA_cumsum,
            initstates_ptr=initial_states,
            seq_idx_ptr=seq_idx,
            chunk_offsets_ptr=chunk_offsets,
            chunk_meta_num=len(chunk_offsets)
            if chunk_offsets is not None else 0,
            dim=dim,
            nchunks=nchunks,
            seqlen=seqlen if seq_idx is not None else 0,
            chunk_size=chunk_size if seq_idx is not None else 0,
            stride_states_chunk=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_dim=states.stride(2),
            stride_out_chunk=out.stride(0),
            stride_out_head=out.stride(1),
            stride_out_dim=out.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_initstates_batch=initial_states_strides[0],
            stride_initstates_head=initial_states_strides[1],
            stride_initstates_dim=initial_states_strides[2],
            stride_seq_idx_seqlen=seq_idx.stride(0),
            HAS_INITSTATES=initial_states is not None,
        )
    return out
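

# Unoptimized reference of the kernel above for the single-sequence case
# (added as a reading aid, not part of the original file; it ignores seq_idx
# boundaries, chunk_offsets and initial states):
def _state_passing_ref(states, dA_cumsum):
    # states: (nchunks, nheads, dim), dA_cumsum: (nheads, nchunks, chunk_size)
    nchunks, nheads, dim = states.shape
    out = torch.empty_like(states)
    carry = torch.zeros(nheads, dim, dtype=torch.float32, device=states.device)
    for c in range(nchunks):
        # out[c] is the state at the *start* of chunk c
        out[c] = carry
        decay = torch.exp(dA_cumsum[:, c, -1]).float()  # (nheads,)
        carry = decay[:, None] * carry + states[c].float()
    return out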