mkdir triton package and move triton files (#4420)
### What this PR does / why we need it?
Create a dedicated `triton` package (`vllm_ascend/ops/triton/`) and move the Triton kernel files into it.
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
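The new import paths follow directly from the file list below; whether a given call site needs updating depends on the old module locations, which this diff does not show. A minimal sketch of the paths after this PR:

```python
# New import paths after this PR (module names taken from the file list below).
from vllm_ascend.ops.triton.fla.fla import torch_chunk_gated_delta_rule
from vllm_ascend.ops.triton.fla.sigmoid_gating import (
    fused_recurrent_gated_delta_rule_fwd_kernel)
from vllm_ascend.ops.triton.mamba.casual_conv1d import (causal_conv1d_fn,
                                                        causal_conv1d_update_npu)
```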
Signed-off-by: shiyuan680 <917935075@qq.com>
vllm_ascend/ops/triton/__init__.py (new file, 0 lines)
vllm_ascend/ops/triton/fla/__init__.py (new file, 0 lines)
vllm_ascend/ops/triton/fla/fla.py (new file, 299 lines)
@@ -0,0 +1,299 @@
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
# mypy: ignore-errors

import torch
import torch.nn.functional as F
from vllm.triton_utils import tl, triton

MAX_CORES = 65535


@triton.heuristics({
    "HAS_BIAS": lambda args: args["B"] is not None,
    "HAS_Z": lambda args: args["Z"] is not None,
})
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X_base
    N,  # number of columns in X_base
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    N_CORES: tl.constexpr,
):
    # Map the program id to the row of X_base and Y_base it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)

    BLOCK_ROWS = M if M < N_CORES else N_CORES
    n_iters = M // BLOCK_ROWS
    remain = M % BLOCK_ROWS
    if row < remain:
        n_iters = n_iters + 1

    for i in tl.range(n_iters):
        X_base = X + (i * BLOCK_ROWS *
                      stride_x_row) + row * stride_x_row + group * N
        Y_base = Y + (i * BLOCK_ROWS *
                      stride_y_row) + row * stride_y_row + group * N
        if HAS_Z:
            Z_base = Z + (i * BLOCK_ROWS *
                          stride_z_row) + row * stride_z_row + group * N
        if not IS_RMS_NORM:
            Mean_base = Mean + (i * BLOCK_ROWS) + group * M
        Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
        W_base = W + group * N
        if HAS_BIAS:
            B_base = B + group * N
        # Compute mean and variance
        cols = tl.arange(0, BLOCK_N)
        x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
        if HAS_Z and not NORM_BEFORE_GATE:
            z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
            x *= z * tl.sigmoid(z)
        if not IS_RMS_NORM:
            mean = tl.sum(x, axis=0) / N
            tl.store(Mean_base + row, mean)
            xbar = tl.where(cols < N, x - mean, 0.)
            var = tl.sum(xbar * xbar, axis=0) / N
        else:
            xbar = tl.where(cols < N, x, 0.)
            var = tl.sum(xbar * xbar, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Rstd_base + row, rstd)
        # Normalize and apply linear transformation
        mask = cols < N
        w = tl.load(W_base + cols, mask=mask).to(tl.float32)
        if HAS_BIAS:
            b = tl.load(B_base + cols, mask=mask).to(tl.float32)
        x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        y = x_hat * w + b if HAS_BIAS else x_hat * w
        if HAS_Z and NORM_BEFORE_GATE:
            z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
            y *= z * tl.sigmoid(z)
        # Write output
        tl.store(Y_base + cols, y, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N, )
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N, )
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
            if not is_rms_norm else None)
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M if M < MAX_CORES else MAX_CORES, ngroups)
    with torch.npu.device(x.device.index):
        layer_norm_fwd_kernel[grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z.stride(0) if z is not None else 0,
            M,
            group_size,
            eps,
            BLOCK_N=BLOCK_N,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            N_CORES=MAX_CORES,
            num_warps=num_warps,
        )
    return out, mean, rstd


class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        return y.reshape(x_shape_og)


def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = F.normalize(query, p=2, dim=-1)
        key = F.normalize(key, p=2, dim=-1)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]

    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
    key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    tot_heads = num_heads + pad_size
    scale = 1 / (query.shape[-1]**0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    mask = torch.triu(torch.ones(chunk_size,
                                 chunk_size,
                                 dtype=torch.bool,
                                 device=query.device),
                      diagonal=0)

    # chunk decay
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) -
                   g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -(
        (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    last_recurrent_state = (torch.zeros(batch_size, sequence_length,
                                        k_head_dim, v_head_dim).to(value) if
                            initial_state is None else initial_state.to(value))

    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size,
                                 chunk_size,
                                 dtype=torch.bool,
                                 device=query.device),
                      diagonal=1)

    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) *
                decay_mask[:, :, i]).masked_fill_(mask, 0)
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp() +
            (k_i *
             (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
                 -1, -2) @ v_new)

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
                                          core_attn_out.shape[1], -1,
                                          core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1,
                                            2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
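Reviewer note: the pure-PyTorch fallback above runs on CPU, so it can be smoke-tested without an NPU. Below is a minimal sketch, assuming the (batch, seq_len, heads, head_dim) input layout implied by the leading `transpose(1, 2)` and the hard-coded `repeat_interleave(2, dim=1)` (so Q/K carry half as many heads as V); all sizes are illustrative:

```python
import torch
from vllm_ascend.ops.triton.fla.fla import torch_chunk_gated_delta_rule

# Illustrative shapes: batch=2, seq_len=128, 1 shared Q/K head
# (repeat-interleaved by 2 inside the function), 2 value heads, head_dim=64.
B, T, D = 2, 128, 64
query = torch.randn(B, T, 1, D)
key = torch.randn(B, T, 1, D)
value = torch.randn(B, T, 2, D)
g = -torch.rand(B, T, 2)    # log-space decay gates, in (-1, 0)
beta = torch.rand(B, T, 2)  # per-token update strength

out, final_state = torch_chunk_gated_delta_rule(
    query, key, value, g, beta,
    chunk_size=64,
    output_final_state=True,
    use_qk_l2norm_in_kernel=True,
)
print(out.shape)  # torch.Size([2, 128, 2, 64])
```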
vllm_ascend/ops/triton/fla/sigmoid_gating.py (new file, 300 lines)
@@ -0,0 +1,300 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors

import os

from vllm.triton_utils import tl, tldevice, triton

if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:

    @triton.jit
    def div_normal(x, y):
        return x / y

    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2


@triton.heuristics({
    'USE_INITIAL_STATE':
    lambda args: args['h0'] is not None,
    'IS_VARLEN':
    lambda args: args['cu_seqlens'] is not None,
    "IS_CONTINUOUS_BATCHING":
    lambda args: args['ssm_state_indices'] is not None,
    "IS_SPEC_DECODING":
    lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is a headwise vector or a scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t

        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t

        if not IS_KDA:
            p_g = g + bos * HV + i_hv + HV * i_t
        else:
            p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k

        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]
        # b_h *= tl.exp(b_g)
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # keep the states for multi-query tokens
    if INPLACE_FINAL_STATE:
        p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                            i_t).to(tl.int64) * stride_final_state_token
    else:
        p_ht = ht + (bos + i_t) * stride_final_state_token
    p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


@triton.heuristics({
    'USE_INITIAL_STATE':
    lambda args: args['h0'] is not None,
    'IS_VARLEN':
    lambda args: args['cu_seqlens'] is not None,
    "IS_CONTINUOUS_BATCHING":
    lambda args: args['ssm_state_indices'] is not None,
    "IS_SPEC_DECODING":
    lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is a headwise vector or a scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t
        p_g = g + bos * HV + i_hv + HV * i_t
        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_g = tl.load(p_g).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]
        # b_h *= tl.exp(b_g)
        b_h *= exp(b_g)
        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # keep the states for multi-query tokens
    if INPLACE_FINAL_STATE:
        p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                            i_t).to(tl.int64) * stride_final_state_token
    else:
        p_ht = ht + (bos + i_t) * stride_final_state_token
    p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
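The launch wrapper for these kernels is not part of this diff, but the `program_id` decomposition above (`i_k`, `i_v`, `i_nh = i_n * HV + i_hv`) pins down the grid shape. A hedged sketch of the host-side grid computation, with every shape name an illustrative assumption:

```python
import triton

# Illustrative sizes: 4 sequences, 8 value heads, 128-dim keys/values.
N, HV, K, V = 4, 8, 128, 128
BK = triton.next_power_of_2(K)           # one K-block here, so i_k == 0
BV = min(64, triton.next_power_of_2(V))  # split V into BV-wide blocks
grid = (
    triton.cdiv(K, BK),  # program_id(0) -> i_k
    triton.cdiv(V, BV),  # program_id(1) -> i_v
    N * HV,              # program_id(2) -> i_nh = i_n * HV + i_hv
)
print(grid)  # (1, 2, 32)
```

Note that the output pointer is offset by `i_k * all`, which suggests the wrapper allocates `o` with a leading K-block axis and reduces over it when `cdiv(K, BK) > 1`; that wrapper is not shown in this diff.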
vllm_ascend/ops/triton/mamba/__init__.py (new file, 0 lines)
vllm_ascend/ops/triton/mamba/casual_conv1d.py (new file, 539 lines)
@@ -0,0 +1,539 @@
# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0

# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors

from typing import Optional, Union

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

PAD_SLOT_ID = -1


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape

    if initial_states is None:
        out = F.conv1d(x,
                       weight.unsqueeze(1),
                       bias,
                       padding=width - 1,
                       groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in)  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out[..., :(width - 1)].copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence, prepended by 0.
        for example: query_start_loc = torch.Tensor([0, 10, 16, 17]),
        x.shape=(dim, 17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the
        initial state for the calculations
    conv_states: (..., dim, width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    out_ref = []
    out_ref_b = []
    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)

    for i in range(len(seqlens)):
        x_s = splits[i]
        if cache_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
                initial_states=conv_states[cache_indices[i]]
                if has_initial_state[i] else None))
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    return out_ref_tensor


@triton.jit()
def _causal_conv1d_update_kernel(
    # Pointers to matrices
    x_ptr,  # (batch, dim, seqlen)
    w_ptr,  # (dim, width)
    bias_ptr,
    conv_state_ptr,
    cache_seqlens_ptr,  # circular buffer
    conv_state_indices_ptr,
    num_accepted_tokens_ptr,
    intermediate_conv_window_ptr,
    o_ptr,  # (batch, dim, seqlen)
    # Matrix dimensions
    batch: int,
    dim: tl.constexpr,
    seqlen: tl.constexpr,
    state_len: tl.constexpr,
    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
    # Strides
    stride_x_seq: tl.constexpr,
    stride_x_dim: tl.constexpr,
    stride_x_token: tl.constexpr,
    stride_w_dim: tl.constexpr,
    stride_w_width: tl.constexpr,
    stride_conv_state_seq: tl.constexpr,
    stride_conv_state_dim: tl.constexpr,
    stride_conv_state_tok: tl.constexpr,
    stride_state_indices: tl.constexpr,
    stride_inter_seq: tl.constexpr,
    stride_inter_step: tl.constexpr,
    stride_inter_dim: tl.constexpr,
    stride_inter_win: tl.constexpr,
    stride_o_seq: tl.constexpr,
    stride_o_dim: tl.constexpr,
    stride_o_token: tl.constexpr,
    # others
    pad_slot_id: tl.constexpr,
    # Meta-parameters
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,
    SILU_ACTIVATION: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    NP2_STATELEN: tl.constexpr,
    USE_PAD_SLOT: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SAVE_INTERMEDIATE: tl.constexpr,
):
    # ruff: noqa: E501
    idx_seq = tl.program_id(0)
    if idx_seq >= batch:
        return

    # [BLOCK_N,] elements along the feature-dimension (channel)
    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    if IS_CONTINUOUS_BATCHING:
        # mask = idx_seq < batch
        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
                                         idx_seq * stride_state_indices).to(
                                             tl.int64)
    else:
        conv_state_batch_coord = idx_seq
    if USE_PAD_SLOT:  # noqa
        if conv_state_batch_coord == pad_slot_id:
            # not processing as this is not the actual sequence
            return

    if IS_SPEC_DECODING:
        # The rolling of conv state:
        #
        # Before forward, the conv_state is:
        # [history1, history2, ..., historyM].
        #
        # After forward, the conv_state becomes:
        # [history2, ..., historyM, draft1, draft2, ..., draftN].
        #
        # After acceptance, it becomes:
        #
        # - accept 1 token:  [history2, ..., historyM, draft1]
        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
        # - and so on.
        conv_state_token_offset = tl.load(num_accepted_tokens_ptr +
                                          idx_seq) - 1
    else:
        conv_state_token_offset = 0

    # STEP 1: READ init_state data
    conv_states_base = (conv_state_ptr +
                        (conv_state_batch_coord * stride_conv_state_seq) +
                        (idx_feats * stride_conv_state_dim))
    mask_w = idx_feats < dim

    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
    if KERNEL_WIDTH >= 2:
        conv_states_ptrs = prior_tokens  # [BLOCK_N]
        col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 3:
        conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok  # [BLOCK_N]
        col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 4:
        conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
        col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH == 5:
        conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
        #col3 = tl.load(conv_states_ptrs, mask_w, 0.0)

    # STEP 2: assume state_len > seqlen
    idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]

    # The conv_state update works in a sliding-window manner:
    # at each forward pass the tokens are shifted by 1, so we
    # load starting from idx_tokens + 1.
    conv_state_ptrs_source = (
        conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
        conv_state_token_offset * stride_conv_state_tok +
        (idx_feats * stride_conv_state_dim)[None, :] +
        ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
    )  # [BLOCK_M, BLOCK_N]
    mask = ((conv_state_batch_coord < num_cache_lines)
            & ((idx_tokens + seqlen) < state_len)[:, None]
            & (idx_feats < dim)[None, :])
    conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

    VAL = state_len - seqlen
    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
                                                 )  # [BLOCK_N]

    x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
              )  # [BLOCK_M, BLOCK_N]

    mask_x = ((idx_tokens - VAL >= 0)[:, None]
              & (idx_tokens - VAL < seqlen)[:, None]
              & (idx_feats < dim)[None, :]
              )  # token-index # token-index # feature-index
    loaded_x = tl.load(x_ptrs, mask_x, 0.0)
    tl.debug_barrier()

    new_conv_state = tl.where(mask, conv_state, loaded_x)

    conv_state_base = (conv_state_ptr +
                       (conv_state_batch_coord * stride_conv_state_seq) +
                       (idx_feats * stride_conv_state_dim))  # [BLOCK_N,]
    conv_state_ptrs_target = (conv_state_base +
                              (idx_tokens * stride_conv_state_tok)[:, None]
                              )  # [BLOCK_M, BLOCK_N]
    mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
    tl.store(conv_state_ptrs_target, new_conv_state, mask)

    # STEP 3: init accumulator
    if HAS_BIAS:
        bias = bias_ptr + idx_feats
        mask_bias = idx_feats < dim
        acc_preload = tl.load(bias, mask=mask_bias,
                              other=0.0).to(tl.float32)  # [BLOCK_N]
    else:
        acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)

    # STEP 4:
    # PRE-LOAD WEIGHTS
    # first kernel column, configured for weights to handle BLOCK_N features in range
    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
    mask_w = idx_feats < dim
    if KERNEL_WIDTH >= 2:
        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 3:
        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 4:
        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)

    x_base_1d = x_base  # starting of chunk [BLOCK_N]
    mask_x_1d = idx_feats < dim

    # STEP 5: compute each token
    for idx_token in tl.static_range(seqlen):
        acc = acc_preload

        matrix_w = w_col0
        matrix_x = col0
        for j in tl.static_range(KERNEL_WIDTH):
            if KERNEL_WIDTH == 2:
                if j == 1:  # KERNEL_WIDTH - 1
                    matrix_w = w_col1
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 3:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 4:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    matrix_x = col2
                elif j == 3:
                    matrix_w = w_col3
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

            acc += matrix_x * matrix_w  # [BLOCK_N]

        if KERNEL_WIDTH == 2:
            col0 = matrix_x
        elif KERNEL_WIDTH == 3:
            col0 = col1
            col1 = matrix_x
        elif KERNEL_WIDTH == 4:
            col0 = col1
            col1 = col2
            col2 = matrix_x

        if SILU_ACTIVATION:
            acc = acc / (1 + tl.exp(-acc))
        # mask_1d = (idx_token < seqlen) & (
        #     idx_feats < dim
        # )  # token-index # feature-index
        maskL = idx_feats < dim
        maskR = tl.full(maskL.shape, False, tl.int1)
        mask_1d = tl.where(idx_token < seqlen, maskL, maskR)

        o_ptrs = (o_ptr + (idx_seq) * stride_o_seq +
                  idx_token * stride_o_token + (idx_feats * stride_o_dim))

        tl.store(o_ptrs, acc, mask=mask_1d)

        if SAVE_INTERMEDIATE:
            # Save the window state after consuming this token
            # Layout: [seq(cache line), step, dim, win(K-1)]
            base_ptr = (intermediate_conv_window_ptr +
                        conv_state_batch_coord * stride_inter_seq +
                        idx_token * stride_inter_step +
                        idx_feats * stride_inter_dim)
            if KERNEL_WIDTH >= 2:
                tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
            if KERNEL_WIDTH >= 3:
                tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
            if KERNEL_WIDTH >= 4:
                tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)


def causal_conv1d_update_npu(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Union[bool, str, None] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    intermediate_conv_window: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
    metadata=None,
    validate_data=False,
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
        [shape=2: single token prediction]
        [shape=3: single or multiple tokens prediction]
    conv_state: (..., dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state
        starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim) or (batch, dim, seqlen)
    """
    if validate_data:
        assert cache_seqlens is None  # not implemented yet - ok for vLLM
        assert pad_slot_id is not None
        assert x.stride(1) == 1
    if isinstance(activation, bool):
        activation = "silu" if activation is True else None
    elif activation is not None:
        assert activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        # make it (batch, dim, seqlen) with seqlen == 1
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    _, width = weight.shape
    # conv_state: (..., dim, state_len), where state_len >= width - 1
    num_cache_lines, _, state_len = conv_state.size()

    if validate_data:
        assert dim == weight.size(0)
        assert (
            conv_state.stride(-2) == 1
        ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
        assert state_len >= width - 1
        # when above happens, we don't shift-left to keep any records in conv_state
        assert dim == conv_state.size(1)
        if conv_state_indices is None:
            assert conv_state.size(0) >= batch
        else:
            assert (batch, ) == conv_state_indices.shape

        assert num_cache_lines >= batch
        assert weight.stride(1) == 1  # Need this
        assert cache_seqlens is None  # not needed for vLLM - circular buffer

    # adopt the strategy in vLLM that overwrites 'x' directly, rather than creating a new tensor 'o'
    out = x
    stride_w_dim, stride_w_width = weight.stride()

    stride_x_seq, stride_x_dim, stride_x_token = x.stride(
    )  # X (batch, dim, seqlen)

    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
    )
    stride_state_indices = (conv_state_indices.stride(0)
                            if conv_state_indices is not None else 0)
    state_len = width - 1 + (seqlen - 1)  # effective state_len needed
    np2_statelen = triton.next_power_of_2(state_len)

    def grid(META):
        return (
            batch,
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    # prepare intermediate buffer strides if provided
    if intermediate_conv_window is not None:
        stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
            intermediate_conv_window.stride(0),
            intermediate_conv_window.stride(1),
            intermediate_conv_window.stride(2),
            intermediate_conv_window.stride(3),
        )
    else:
        stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0

    _causal_conv1d_update_kernel[grid](
        # Pointers to matrices
        x,
        weight,
        bias,
        conv_state,
        cache_seqlens,
        conv_state_indices,
        num_accepted_tokens,
        intermediate_conv_window
        if intermediate_conv_window is not None else x,
        out,
        # Matrix dimensions
        batch,
        dim,
        seqlen,
        state_len,
        num_cache_lines,
        # stride
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_inter_seq,
        stride_inter_step,
        stride_inter_dim,
        stride_inter_win,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        # others
        pad_slot_id,
        # META
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=128,
        SAVE_INTERMEDIATE=intermediate_conv_window is not None,
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out
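The reference path `causal_conv1d_fn` is plain PyTorch, so the varlen contract in its docstring can be checked on CPU. A runnable sketch following the docstring's example (three sequences of 10, 6 and 1 tokens packed into `x` of shape `(dim, 17)`; all sizes are illustrative):

```python
import torch
from vllm_ascend.ops.triton.mamba.casual_conv1d import causal_conv1d_fn

dim, width = 8, 4
x = torch.randn(dim, 17)  # three packed sequences: 10 + 6 + 1 tokens
query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32)
cache_indices = torch.tensor([0, 1, 2], dtype=torch.int32)
has_initial_state = torch.tensor([False, False, True])
conv_states = torch.zeros(4, dim, width - 1)  # one line per cache slot
weight = torch.randn(dim, width)

out = causal_conv1d_fn(x, weight, None, query_start_loc, cache_indices,
                       has_initial_state, conv_states)
print(out.shape)  # torch.Size([8, 17]); conv_states now holds the final states
```

`causal_conv1d_update_npu`, by contrast, launches the Triton kernel and deliberately writes its output into `x` in place (`out = x`), mirroring vLLM's decode-path strategy of avoiding a fresh output allocation; running it requires an Ascend NPU build of Triton.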