Files
xc-llm-ascend/vllm_ascend/ops/triton/mamba/causal_conv1d.py
SILONG ZENG 1e3c1e76bf [Lint]Add lint hooks for clang-format, shellcheck, forbidden imports, and boolean context manager checks (#7511)
### What this PR does / why we need it?
This PR introduces several upstream `vllm`-aligned lint hooks into
`vllm-ascend` and makes them part of the actual `pre-commit` flow.

Main changes in this PR:
- add `check-boolean-context-manager` to catch boolean expressions in
`with` statements
- add `check-forbidden-imports` to forbid direct `re` imports and
disallowed direct `triton` imports
- enable shell script linting through `tools/shellcheck.sh`
- add a root `.clang-format` aligned with upstream `vllm`, enable
`clang-format` in `pre-commit`, and temporarily **exclude all of `csrc/**`**
from `clang-format` to avoid bringing a large native-code reformat into
this PR

This PR focuses on landing the smaller and immediately useful lint
alignment first, without mixing in the larger requirements-management
migration.

### Does this PR introduce _any_ user-facing change?
No.

This PR only updates repository lint configuration, static checks, and
internal import/style enforcement. It does not change runtime behavior
or public interfaces.

### How was this patch tested?
Tested locally in the project virtual environment.

Commands used:
```bash
bash format.sh
```
Verified checks passed:
``` bash
ruff check...............................................................Passed
ruff format..............................................................Passed
codespell................................................................Passed
typos....................................................................Passed
clang-format.............................................................Passed
Lint GitHub Actions workflow files.......................................Passed
Lint shell scripts.......................................................Passed
Lint PNG exports from excalidraw.........................................Passed
Check for spaces in all filenames........................................Passed
Enforce __init__.py in Python packages...................................Passed
Check for forbidden imports..............................................Passed
Check for boolean ops in with-statements.................................Passed
Suggestion...............................................................Passed
- hook id: suggestion
- duration: 0s

To bypass pre-commit hooks, add --no-verify to git commit.
```
**Note:**
clang-format is enabled but currently excludes everything under `csrc/**`.


- vLLM version: v0.17.0
- vLLM main: 8b6325758c

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
2026-03-24 20:03:01 +08:00

693 lines
28 KiB
Python

# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors
from typing import Any
import torch
import torch.nn.functional as F
from vllm.distributed import get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
initial_states: torch.Tensor | None = None,
return_final_states: bool = False,
final_states_out: torch.Tensor | None = None,
activation: str | None = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    activation: str | None = "silu",
    conv_states: torch.Tensor | None = None,
    has_initial_state: torch.Tensor | None = None,
    cache_indices: torch.Tensor | None = None,
    query_start_loc: torch.Tensor | None = None,
    metadata: Any | None = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    Pure-PyTorch varlen causal conv1d for prefill.

    Each non-padded sequence is convolved with ``causal_conv1d_ref``. Its
    ``conv_states`` slot serves as the initial state, and that slot is then
    overwritten in place with the sequence's trailing ``width - 1`` inputs.

    x: (dim, cu_seq_len)
        sequences are concatenated from left to right along the last axis
    weight: (dim, width)
    bias: (dim,)
    activation: either None or "silu" or "swish"
    conv_states: (..., dim, width - 1)
        updated in place
    has_initial_state: (batch) bool
        NOTE(review): not read here. Every processed sequence uses its
        conv_states slot as initial state, so confirm the slots of new
        sequences are prepared elsewhere.
    cache_indices: (batch) int32
        state slot of each sequence: conv_state = conv_states[cache_indices[batch_id]]
    query_start_loc: (batch + 1) int32
        cumulative sequence lengths prepended by 0, e.g.
        query_start_loc = [0, 10, 16, 17] with x.shape == (dim, 17)
    metadata: unused; kept for interface compatibility
    pad_slot_id: int
        sequences whose cache index equals this value are skipped entirely
        (no output, no state update). For example, with
        cache_indices = [pad_slot_id, 1, 20, pad_slot_id], entries 0 and 3
        are not processed.
    out: (dim, n), where n is the total length of the non-padded sequences
        (their outputs concatenated along the last axis)

    With prefill context parallelism (PCP world size > 1), ranks > 0 first
    seed the prefill slots with the trailing inputs of the previous rank's
    chunk. After the loop, every rank sets those slots to the trailing inputs
    of the last rank's chunk.
    """
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = next(iter(attn_metadata.values()), None)
    num_decodes = attn_metadata.num_decodes if attn_metadata is not None else 0
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
    width = weight.shape[1]
    seqlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
    splits = torch.split(x, seqlens, dim=-1)
    # One host transfer instead of a device sync per sequence in the loop below.
    slot_ids = cache_indices.tolist()

    pcp_group = get_pcp_group()
    use_pcp = pcp_group.world_size > 1
    if use_pcp:
        # The per-sequence tails are only needed to hand conv history between PCP ranks.
        last_width_prefill_x = extract_last_width(x, query_start_loc[num_decodes:], conv_states.shape[-1])
        all_last_width_prefill_x = pcp_group.all_gather(last_width_prefill_x.unsqueeze(0).contiguous(), 0)
        pcp_rank = pcp_group.rank_in_group
        if pcp_rank > 0:
            # Seed this rank's prefill slots with the tail of the previous rank's chunk.
            # NOTE(review): pad entries in cache_indices[num_decodes:] are not filtered here.
            conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[pcp_rank - 1, ...]

    outputs = []
    for i, x_s in enumerate(splits):
        slot = slot_ids[i]
        if slot == pad_slot_id:
            continue
        # A view into conv_states: read as initial state, then overwritten in place.
        state_view = conv_states[slot][..., : (width - 1)]
        seq_out, _ = causal_conv1d_ref(
            x_s,
            weight,
            bias,
            activation=activation,
            return_final_states=True,
            final_states_out=state_view.unsqueeze(0),
            initial_states=state_view,
        )
        outputs.append(seq_out)

    if use_pcp:
        # All ranks end with the state produced by the last rank's chunk.
        conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[-1, ...]
    return torch.cat(outputs, dim=-1)
def extract_last_width(x, start_loc, width):
    """Gather the trailing ``width`` tokens of every sequence in a varlen batch.

    x: (dim, total_tokens), with sequences concatenated along the last axis.
    start_loc: (num_seqs + 1,) cumulative offsets. Entry ``k + 1`` is the
        exclusive end of sequence ``k``.
    width: number of trailing tokens to take per sequence.

    Returns a (num_seqs, dim, width) tensor. No bounds handling is done: a
    sequence shorter than ``width`` picks up tokens from before its start
    (or wraps through negative indexing).
    """
    window = torch.arange(width, device=x.device)
    window_start = start_loc[1:] - width
    token_idx = window_start[:, None] + window[None, :]  # (num_seqs, width)
    gathered = x[:, token_idx]  # (dim, num_seqs, width)
    return gathered.permute(1, 0, 2)
@triton.jit
def _causal_conv1d_update_kernel_npu_tiled(
    # Pointers
    x_ptr,  # (batch, seqlen, dim) OR (num_tokens, dim) for varlen, as passed by causal_conv1d_update_npu; addressed only via stride_x_*
    w_ptr,  # (width, dim) as passed by the wrapper (weight transposed); addressed via stride_w_*
    bias_ptr,  # (dim,); only read when HAS_BIAS
    conv_state_ptr,  # (num_cache_lines, state_len, dim) as passed by the wrapper; addressed via stride_conv_state_*
    conv_state_indices_ptr,  # (batch,) cache-line index per sequence; read unconditionally
    num_accepted_tokens_ptr,  # (batch,); only read when IS_SPEC_DECODING
    query_start_loc_ptr,  # (batch + 1); only read when IS_VARLEN
    block_idx_last_scheduled_token,  # (batch,); only read when IS_APC_ENABLED
    initial_state_idx,  # (batch,); only read when IS_APC_ENABLED
    o_ptr,  # same shape as x_ptr (the wrapper passes x itself, so the output overwrites x)
    batch: tl.int32,
    dim: tl.constexpr,
    seqlen: tl.constexpr,  # max seqlen for varlen, or exact seqlen
    state_len: tl.constexpr,  # effective state_len computed in wrapper
    num_cache_lines: tl.constexpr,
    # Strides
    stride_x_seq: tl.constexpr,  # 0 in varlen mode
    stride_x_dim: tl.constexpr,
    stride_x_token: tl.constexpr,
    stride_w_dim: tl.constexpr,
    stride_w_width: tl.constexpr,
    stride_conv_state_seq: tl.constexpr,
    stride_conv_state_dim: tl.constexpr,
    stride_conv_state_tok: tl.constexpr,
    stride_state_indices: tl.constexpr,
    stride_o_seq: tl.constexpr,  # 0 in varlen mode
    stride_o_dim: tl.constexpr,
    stride_o_token: tl.constexpr,
    # others
    pad_slot_id: tl.constexpr,
    # Meta
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,  # <= 6
    SILU_ACTIVATION: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_APC_ENABLED: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    NP2_STATELEN: tl.constexpr,  # next power of two >= state_len (computed by the wrapper)
    USE_PAD_SLOT: tl.constexpr,
    # tiling
    BLOCK_N: tl.constexpr,  # channel tile (C_TILE)
    B_TILE: tl.constexpr,  # batch tile
    T_CHUNK: tl.constexpr,  # token chunk for state update
):
    # Tiled causal conv1d state update + output kernel.
    # Grid axis 0 walks tiles of B_TILE sequences; axis 1 walks tiles of BLOCK_N channels.
    # Each program:
    #   * loads the KERNEL_WIDTH weight columns and the bias for its channel tile once;
    #   * for every sequence of its batch tile:
    #       1. reads KERNEL_WIDTH - 1 history columns from the input cache line,
    #       2. writes the updated state (shifted old state and/or new tokens) to the
    #          output cache line,
    #       3. computes the causal conv (+ optional SiLU) per token and stores it to o_ptr.
    # Inactive lanes (b >= batch, pad slots, empty sequences) are masked rather than
    # returned early, because several sequences share one program instance.
    # program ids
    pid_b = tl.program_id(0)  # batch-tile id
    pid_c = tl.program_id(1)  # channel-tile id
    # channel indices for this program
    idx_feats = pid_c * BLOCK_N + tl.arange(0, BLOCK_N)  # [BLOCK_N]
    mask_w = idx_feats < dim
    # preload weights once per program (shared by B_TILE sequences)
    w_base = w_ptr + idx_feats * stride_w_dim
    # define to avoid "undefined" in branches
    w_col0 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col2 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col3 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col4 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col5 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if KERNEL_WIDTH >= 1:
        w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 2:
        w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 3:
        w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 4:
        w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 5:
        w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 6:
        w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    # bias vector once per program
    if HAS_BIAS:
        acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32)
    else:
        acc_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
    # token index vector for chunked copy
    # NOTE(review): upstream Triton requires tl.arange lengths to be powers of 2, but the
    # wrapper may pass T_CHUNK=48. Presumably the Ascend backend accepts this; confirm.
    tok_vec = tl.arange(0, T_CHUNK)  # [T_CHUNK]
    # process B_TILE sequences inside the same program instance
    for bi in tl.static_range(0, B_TILE):
        b = pid_b * B_TILE + bi  # scalar tl.int32
        lane_active = b < batch  # scalar predicate
        # -------------------------
        # APC mapping (optional)
        # -------------------------
        if IS_APC_ENABLED:
            conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32)
            current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32)
        else:
            conv_state_init = tl.full((), 0, tl.int32)
            current_last_index = tl.full((), 0, tl.int32)
        # input cache line
        conv_states_input_coord = tl.load(
            conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0
        ).to(tl.int64)
        if USE_PAD_SLOT:
            lane_active = lane_active & (conv_states_input_coord != pad_slot_id)
        # -------------------------
        # varlen (optional): revise seqlen_run and state_len_run like original kernel does
        # -------------------------
        if IS_VARLEN:
            qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64)
            qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64)
            seqlen_run = (qe - qs).to(tl.int32)
            # revise effective state_len for shorter sequences (same formula as original)
            state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32)
            x_offset = (qs * stride_x_token).to(tl.int64)
            o_offset = (qs * stride_o_token).to(tl.int64)
        else:
            seqlen_run = tl.full((), seqlen, tl.int32)
            state_len_run = tl.full((), state_len, tl.int32)
            x_offset = (b * stride_x_seq).to(tl.int64)
            o_offset = (b * stride_o_seq).to(tl.int64)
        # empty sequence -> skip (avoid early return because other lanes in tile)
        lane_active = lane_active & (seqlen_run > 0)
        # -------------------------
        # spec decoding offset (optional)
        # -------------------------
        if IS_SPEC_DECODING:
            # history starts at (num_accepted_tokens - 1) tokens into the cache line
            conv_state_token_offset = tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1
            shift = tl.full((), 1, tl.int32)  # sliding by 1 in spec mode
        else:
            conv_state_token_offset = tl.full((), 0, tl.int64)
            shift = seqlen_run  # normal mode shift by seqlen
        # -------------------------
        # STEP 1: read initial history cols BEFORE state update (out==x safe)
        # -------------------------
        conv_states_base = (
            conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim
        )
        prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
        # define history vectors as zeros then load conditionally
        # NOTE(review): history (and x below) are cast to float16 whatever the cache/input
        # dtype is. For bfloat16 or float32 data this may lose range or precision; confirm
        # this is intentional for the NPU backend.
        col0 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col1 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col2 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col3 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col4 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        if KERNEL_WIDTH >= 2:
            col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 3:
            col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 4:
            col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 5:
            col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 6:
            col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        # -------------------------
        # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block)
        # Semantics: conv_state <- concat(old_state, x)[-state_len_run:].
        # - If seqlen_run >= state_len_run: dst[:] = x[seqlen_run - state_len_run : seqlen_run]
        # - Else: keep = state_len_run - seqlen_run,
        # dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run]
        # -------------------------
        # output cache line (differs from the input line only when APC is enabled)
        conv_states_offset = tl.load(
            conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0
        ).to(tl.int64)
        use_shift = seqlen_run < state_len_run
        use_tail = seqlen_run >= state_len_run
        zero_i32 = tl.full((), 0, tl.int32)
        keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32)
        tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32)
        # base pointers
        state_src_base = (
            conv_state_ptr
            + conv_states_input_coord * stride_conv_state_seq
            + conv_state_token_offset * stride_conv_state_tok
            + idx_feats * stride_conv_state_dim
        )
        state_dst_base = conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim
        x_base = x_ptr + x_offset + idx_feats * stride_x_dim
        # A) shift old state into dst[0:keep_shift) (only when seqlen_run < state_len_run)
        # src index >= dst index, so walking chunks forward never reads an already-overwritten slot
        for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
            dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            src_tok = (dst_tok + shift).to(tl.int32)  # [T_CHUNK]
            m_tok = use_shift & (dst_tok < keep_shift) & (src_tok < state_len_run) & (dst_tok < state_len_run)
            m = (
                (lane_active & m_tok)[:, None]
                & mask_w[None, :]
                & (conv_states_input_coord < num_cache_lines)
                & (conv_states_offset < num_cache_lines)
            )
            src_ptrs = state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            vals = tl.load(src_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, vals, mask=m)
        # B) append x into dst[keep_shift : keep_shift+seqlen_run) (only when seqlen_run < state_len_run)
        for t0 in tl.static_range(0, seqlen, T_CHUNK):
            x_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            dst_tok = (keep_shift + x_tok).to(tl.int32)  # [T_CHUNK]
            m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run)
            m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
            x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            x_vals = tl.load(x_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, x_vals, mask=m)
        # C) if seqlen_run >= state_len_run, overwrite dst with the tail of x
        for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
            dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            x_tok = (tail_start + dst_tok).to(tl.int32)  # [T_CHUNK]
            m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run)
            m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
            x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            x_vals = tl.load(x_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, x_vals, mask=m)
        # -------------------------
        # STEP 3/4/5: causal conv1d (+ optional SiLU) and store output
        # This is original STEP3~5, but per-lane and without debug_barrier.
        # -------------------------
        x_base_1d = x_base
        o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim
        # accumulator preload (bias)
        acc_preload = acc_bias
        # compute each token; keep tl.range so varlen can use seqlen_run as runtime trip count (like original)
        for idx_token in tl.range(seqlen_run):
            acc = acc_preload
            # same selection logic as original (unrolled by KERNEL_WIDTH):
            # taps 0..KERNEL_WIDTH-2 use history cols, the last tap uses the current token x[t]
            matrix_w = w_col0
            matrix_x = col0
            for j in tl.static_range(KERNEL_WIDTH):
                if KERNEL_WIDTH == 1:
                    # only x[t] * w0
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                    matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                    matrix_w = w_col0
                elif KERNEL_WIDTH == 2:
                    if j == 1:
                        matrix_w = w_col1
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 3:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 4:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 5:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        matrix_x = col3
                    elif j == 4:
                        matrix_w = w_col4
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 6:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        matrix_x = col3
                    elif j == 4:
                        matrix_w = w_col4
                        matrix_x = col4
                    elif j == 5:
                        matrix_w = w_col5
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                acc += matrix_x.to(tl.float32) * matrix_w  # [BLOCK_N]
            # roll history window (matrix_x now holds x[t], which becomes the newest history col)
            if KERNEL_WIDTH == 2:
                col0 = matrix_x
            elif KERNEL_WIDTH == 3:
                col0 = col1
                col1 = matrix_x
            elif KERNEL_WIDTH == 4:
                col0 = col1
                col1 = col2
                col2 = matrix_x
            elif KERNEL_WIDTH == 5:
                col0 = col1
                col1 = col2
                col2 = col3
                col3 = matrix_x
            elif KERNEL_WIDTH == 6:
                col0 = col1
                col1 = col2
                col2 = col3
                col3 = col4
                col4 = matrix_x
            if SILU_ACTIVATION:
                # SiLU: acc * sigmoid(acc)
                acc = acc / (1.0 + tl.exp(-acc))
            # store output
            o_ptrs = o_base_1d + idx_token * stride_o_token
            tl.store(o_ptrs, acc, mask=lane_active & mask_w)
def causal_conv1d_update_npu(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    activation: bool | str | None = None,
    conv_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    query_start_loc: torch.Tensor | None = None,
    max_query_len: int = -1,
    pad_slot_id: int = PAD_SLOT_ID,
    block_idx_last_scheduled_token: torch.Tensor | None = None,
    initial_state_idx: torch.Tensor | None = None,
    validate_data=False,
):
    """
    Causal depthwise conv1d update (decode / speculative decode) on NPU.

    Appends each sequence's new tokens to its conv state, and returns the causal
    convolution of those tokens, optionally followed by SiLU.

    x: Input tensor which can take the following shapes:
    - `[batch, dim]` - single token prediction (treated as seqlen == 1)
    - `[batch, seqlen, dim]` - single or multiple tokens prediction
    - `[num_tokens, dim]` - continuous batching, where num_tokens is
        the total tokens of all sequences in that batch
    conv_state: (num_cache_lines, dim, state_len), where state_len >= width - 1
        Updated in place. The kernel works on the (num_cache_lines, state_len, dim)
        transpose. If that view is not contiguous, a temporary copy is updated and
        then written back.
    weight: (dim, width)
    bias: (dim,)
    activation: None, a bool (True means "silu"), "silu" or "swish"
    conv_state_indices: (batch,), dtype int32
        Required. Maps each sequence to the conv_state cache line(s) it reads and
        writes (the kernel reads it for every sequence).
    block_idx_last_scheduled_token: (batch,), dtype int32
        The pointer into conv_state_indices, where the last cache block to be filled is located.
    initial_state_idx: (batch,), dtype int32
        The pointer into conv_state_indices, where the cache block containing the initial state is located.
    num_accepted_tokens: (batch,), dtype int32
        If not None, it indicates the number of accepted tokens for each
        sequence in the batch.
        This is used in speculative decoding, where the conv_state is updated
        in a sliding window manner.
    query_start_loc: (batch + 1,) int32
        If not None, the inputs are given in a varlen fashion and this indicates
        the starting index of each sequence in the batch.
    max_query_len: int
        If query_start_loc is not None, this indicates the maximum query
        length in the batch. A negative value (the default) means "unknown":
        it is then derived from query_start_loc, at the cost of a host sync.
    pad_slot_id: int
        if conv_state_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    validate_data: if True, run extra input assertions.
    out: same shape as `x`, with x's original dtype. When x already has
        conv_state's dtype, x itself is overwritten with the result.

    Raises ValueError if conv_state_indices is None.
    """
    if conv_state_indices is None:
        # The kernel unconditionally reads each sequence's cache line from this tensor.
        raise ValueError("conv_state_indices is required by causal_conv1d_update_npu")
    # Kernel layouts: weight (width, dim); conv state (num_cache_lines, state_len, dim).
    weight = weight.transpose(0, 1).contiguous()
    conv_state_t = conv_state.transpose(1, 2)
    conv_state_k = conv_state_t.contiguous()
    # `.contiguous()` is a no-op when the cache is already stored as
    # (num_cache_lines, state_len, dim). Otherwise it returns a private copy, and
    # the kernel's state update must be copied back or it would be silently lost.
    state_needs_writeback = conv_state_k.data_ptr() != conv_state_t.data_ptr()
    if validate_data:
        assert pad_slot_id is not None
        # The channel axis is the innermost one in every supported layout of x.
        assert x.stride(-1) == 1
    if isinstance(activation, bool):
        activation = "silu" if activation is True else None
    elif activation is not None:
        assert activation in ["silu", "swish"]
    original_x_dtype = x.dtype
    x = x.to(conv_state_k.dtype)
    unsqueeze = query_start_loc is None and x.dim() == 2
    if unsqueeze:
        # make it (batch, seqlen, dim) with seqlen == 1
        x = x.unsqueeze(1)
    if query_start_loc is None:
        batch, seqlen, dim = x.shape
    else:
        batch = conv_state_indices.size(0)
        dim = x.size(1)
        if max_query_len < 0:
            # A negative seqlen would skip appending x to the state and mis-size the
            # shift in the kernel, so derive the real bound instead.
            query_lens = query_start_loc[1:] - query_start_loc[:-1]
            max_query_len = int(query_lens.max()) if query_lens.numel() > 0 else 1
        seqlen = max_query_len
    width, _ = weight.shape
    num_cache_lines, _, _ = conv_state_k.size()
    # overwrite-on-x strategy same as original: results are stored into x
    out = x
    stride_w_width, stride_w_dim = weight.stride()
    if query_start_loc is None:
        stride_x_seq, stride_x_token, stride_x_dim = x.stride()
        stride_o_seq, stride_o_token, stride_o_dim = out.stride()
    else:
        stride_x_token, stride_x_dim = x.stride()
        stride_x_seq = 0
        stride_o_token, stride_o_dim = out.stride()
        stride_o_seq = 0
    stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state_k.stride()
    stride_state_indices = conv_state_indices.stride(0)
    # effective state_len exactly as original
    if num_accepted_tokens is not None:
        eff_state_len = width - 1 + (seqlen - 1)
    else:
        eff_state_len = width - 1
    np2_statelen = triton.next_power_of_2(eff_state_len)
    # -------- tiling heuristic--------
    # keep program count around ~[80..160]
    # vector core 40
    # TODO: use driver to get the vector core num
    CORE_HINT = 40
    # channel tile: 512 when dim large (reduce tasks), else 256
    block_n = 512 if dim >= 512 else 256
    g = triton.cdiv(dim, block_n)
    target = 2 * CORE_HINT  # ~80
    b_tile_raw = max(1, (batch * g + target - 1) // target)
    # clamp to small set
    if b_tile_raw <= 1:
        b_tile = 1
    elif b_tile_raw <= 2:
        b_tile = 2
    elif b_tile_raw <= 4:
        b_tile = 4
    else:
        b_tile = 8
    # token chunk based on block_n (32KB UB idea); conservative
    t_chunk = 1 if block_n == 512 else 48

    def grid(META):
        return (
            triton.cdiv(batch, META["B_TILE"]),
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    _causal_conv1d_update_kernel_npu_tiled[grid](
        x,
        weight,
        bias,
        conv_state_k,
        conv_state_indices,
        num_accepted_tokens,
        query_start_loc,
        block_idx_last_scheduled_token,
        initial_state_idx,
        out,
        batch,
        dim,
        seqlen,
        eff_state_len,
        num_cache_lines,
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        pad_slot_id,
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        IS_VARLEN=query_start_loc is not None,
        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=block_n,
        B_TILE=b_tile,
        T_CHUNK=t_chunk,
    )
    if state_needs_writeback:
        # Persist the kernel's state update into the caller's tensor.
        conv_state.copy_(conv_state_k.transpose(1, 2))
    if unsqueeze:
        out = out.squeeze(1)
    return out.to(original_x_dtype)