### What this PR does / why we need it?
This PR introduces several upstream `vllm`-aligned lint hooks into
`vllm-ascend` and makes them part of the actual `pre-commit` flow.
Main changes in this PR:
- add `check-boolean-context-manager` to catch boolean expressions in
`with` statements
- add `check-forbidden-imports` to forbid direct `re` imports and
disallowed direct `triton` imports
- enable shell script linting through `tools/shellcheck.sh`
- add a root `.clang-format` aligned with upstream `vllm`, enable
`clang-format` in `pre-commit`, and temporarily **exclude all of `csrc/**`**
from `clang-format` so that this PR does not include a large reformat of
native code
This PR focuses on landing the smaller and immediately useful lint
alignment first, without mixing in the larger requirements-management
migration.
### Does this PR introduce _any_ user-facing change?
No.
This PR only updates repository lint configuration, static checks, and
internal import/style enforcement. It does not change runtime behavior
or public interfaces.
### How was this patch tested?
Tested locally in the project virtual environment.
Commands used:
```bash
bash format.sh
```
Verified checks passed:
```bash
ruff check...............................................................Passed
ruff format..............................................................Passed
codespell................................................................Passed
typos....................................................................Passed
clang-format.............................................................Passed
Lint GitHub Actions workflow files.......................................Passed
Lint shell scripts.......................................................Passed
Lint PNG exports from excalidraw.........................................Passed
Check for spaces in all filenames........................................Passed
Enforce __init__.py in Python packages...................................Passed
Check for forbidden imports..............................................Passed
Check for boolean ops in with-statements.................................Passed
Suggestion...............................................................Passed
- hook id: suggestion
- duration: 0s
To bypass pre-commit hooks, add --no-verify to git commit.
```
**Note:**
`clang-format` is enabled but currently excludes all of `csrc/**`.
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
693 lines
28 KiB
Python
693 lines
28 KiB
Python
# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# Copyright (c) 2024, Tri Dao.
|
|
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
|
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
|
# mypy: ignore-errors
|
|
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from vllm.distributed import get_pcp_group
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
|
|
|
|
|
|
def causal_conv1d_ref(
|
|
x: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
bias: torch.Tensor | None = None,
|
|
initial_states: torch.Tensor | None = None,
|
|
return_final_states: bool = False,
|
|
final_states_out: torch.Tensor | None = None,
|
|
activation: str | None = "silu",
|
|
):
|
|
"""
|
|
x: (batch, dim, seqlen)
|
|
weight: (dim, width)
|
|
bias: (dim,)
|
|
initial_states: (batch, dim, width - 1)
|
|
final_states_out: (batch, dim, width - 1)
|
|
out: (batch, dim, seqlen)
|
|
"""
|
|
if activation not in [None, "silu", "swish"]:
|
|
raise NotImplementedError("activation must be None, silu, or swish")
|
|
dtype_in = x.dtype
|
|
x = x.to(weight.dtype)
|
|
seqlen = x.shape[-1]
|
|
dim, width = weight.shape
|
|
|
|
if initial_states is None:
|
|
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
|
else:
|
|
x = torch.cat([initial_states, x], dim=-1)
|
|
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
|
out = out[..., :seqlen]
|
|
|
|
if return_final_states:
|
|
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in) # (batch, dim, width - 1)
|
|
if final_states_out is not None:
|
|
final_states_out.copy_(final_states)
|
|
else:
|
|
final_states_out = final_states
|
|
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
|
return (out, None) if not return_final_states else (out, final_states_out)
|
|
|
|
|
|
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    activation: str | None = "silu",
    conv_states: torch.Tensor | None = None,
    has_initial_state: torch.Tensor | None = None,
    cache_indices: torch.Tensor | None = None,
    query_start_loc: torch.Tensor | None = None,
    metadata: Any | None = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. prepended by 0.
        for example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        Accepted for interface compatibility but currently NOT read: the
        cached conv state is always used as the initial state.
    conv_states: (...,dim,width - 1) itype
        updated inplace
    activation: either None or "silu" or "swish"
    metadata: accepted for interface compatibility; not read.
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3

    With prefill context parallelism (PCP world size > 1), each rank
    all-gathers the trailing conv_states.shape[-1] tokens of its prefill
    sequences. A rank with rank > 0 first seeds conv_states of the prefill
    entries (cache_indices[num_decodes:]) with rank - 1's tail. After the
    local update, every rank overwrites those entries with the last rank's
    tail. num_decodes comes from the current forward context's
    attn_metadata (0 when absent).

    out: (dim, total_len) -- per-sequence outputs of non-padded entries,
        concatenated along the last axis (empty along that axis if every
        entry is padded).
    """
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        # NOTE(review): takes the first layer's metadata; presumably all
        # layers agree on num_decodes.
        attn_metadata = next(iter(attn_metadata.values()), None)
    num_decodes = attn_metadata.num_decodes if attn_metadata is not None else 0

    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    seqlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
    splits = torch.split(x, seqlens, dim=-1)
    width = weight.shape[1]

    pcp_group = get_pcp_group()
    use_pcp = pcp_group.world_size > 1
    if use_pcp:
        # Only needed for PCP: gather every rank's per-sequence prefill tail.
        last_width_prefill_x = extract_last_width(x, query_start_loc[num_decodes:], conv_states.shape[-1])
        all_last_width_prefill_x = pcp_group.all_gather(last_width_prefill_x.unsqueeze(0).contiguous(), 0)
        pcp_rank = pcp_group.rank_in_group
        if pcp_rank > 0:
            conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[pcp_rank - 1, ...]

    outputs = []
    for i, x_s in enumerate(splits):
        slot = cache_indices[i]
        # Honour the caller-supplied pad id (previously the module constant was used).
        if slot == pad_slot_id:
            continue
        # View into the cache line: read as initial state, then updated in place.
        state = conv_states[slot][..., : (width - 1)]
        out_s, _ = causal_conv1d_ref(
            x_s,
            weight,
            bias,
            activation=activation,
            return_final_states=True,
            final_states_out=state.unsqueeze(0),
            initial_states=state,
        )
        outputs.append(out_s)

    if use_pcp:
        conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[-1, ...]

    if not outputs:
        # Every entry was padded; torch.cat would raise on an empty list.
        return x.new_empty((*x.shape[:-1], 0))
    return torch.cat(outputs, dim=-1)
|
|
|
|
|
|
def extract_last_width(x, start_loc, width):
    """Gather the trailing ``width`` tokens of every sequence in a packed input.

    Args:
        x: 2-D packed tensor (dim, total_tokens) with sequences concatenated
            along the last axis.
        start_loc: cumulative offsets (num_seqs + 1,); entry s + 1 is the
            exclusive end of sequence s.
        width: number of trailing tokens taken per sequence.

    Returns:
        Tensor of shape (num_seqs, dim, width) (a permuted view of the
        gathered data). Sequences shorter than ``width`` produce indices
        below their own start, so tokens from earlier sequences (or wrapped
        negative indices) are picked up; callers must tolerate this.
    """
    ends = start_loc[1:]
    window = torch.arange(width, device=x.device)
    # Row s holds token positions ends[s] - width, ..., ends[s] - 1.
    token_idx = (ends - width)[:, None] + window[None, :]
    gathered = x[:, token_idx]  # (dim, num_seqs, width)
    return gathered.permute(1, 0, 2)
|
|
|
|
|
|
@triton.jit
def _causal_conv1d_update_kernel_npu_tiled(
    # Pointers
    x_ptr,  # addressed via strides; wrapper passes (batch, seqlen, dim) OR (num_tokens, dim) for varlen
    w_ptr,  # (dim, width) addressed via strides (wrapper passes a (width, dim) contiguous copy)
    bias_ptr,
    conv_state_ptr,  # logical (num_cache_lines, dim, state_len); addressed via strides
    conv_state_indices_ptr,
    num_accepted_tokens_ptr,
    query_start_loc_ptr,  # (batch + 1)
    block_idx_last_scheduled_token,  # (batch,)
    initial_state_idx,  # (batch,)
    o_ptr,  # same shape as x_ptr
    batch: tl.int32,
    dim: tl.constexpr,
    seqlen: tl.constexpr,  # max seqlen for varlen, or exact seqlen
    state_len: tl.constexpr,  # effective state_len computed in wrapper
    num_cache_lines: tl.constexpr,
    # Strides
    stride_x_seq: tl.constexpr,
    stride_x_dim: tl.constexpr,
    stride_x_token: tl.constexpr,
    stride_w_dim: tl.constexpr,
    stride_w_width: tl.constexpr,
    stride_conv_state_seq: tl.constexpr,
    stride_conv_state_dim: tl.constexpr,
    stride_conv_state_tok: tl.constexpr,
    stride_state_indices: tl.constexpr,
    stride_o_seq: tl.constexpr,
    stride_o_dim: tl.constexpr,
    stride_o_token: tl.constexpr,
    # others
    pad_slot_id: tl.constexpr,
    # Meta
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,  # <= 6
    SILU_ACTIVATION: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_APC_ENABLED: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    NP2_STATELEN: tl.constexpr,
    USE_PAD_SLOT: tl.constexpr,
    # tiling
    BLOCK_N: tl.constexpr,  # channel tile (C_TILE)
    B_TILE: tl.constexpr,  # batch tile
    T_CHUNK: tl.constexpr,  # token chunk for state update
):
    # Tiled causal conv1d update (decode / speculative decode) for NPU.
    # Grid axis 0 walks tiles of B_TILE sequences, axis 1 walks tiles of
    # BLOCK_N channels. For each sequence in the tile the program:
    #   1. reads the first KERNEL_WIDTH - 1 history columns of its conv state,
    #   2. rewrites the conv state window in T_CHUNK-token chunks,
    #   3. computes the depthwise causal conv (+ optional SiLU) into o_ptr.
    # Lanes that must be skipped (b >= batch, padded slot, empty sequence) are
    # masked via `lane_active` instead of returning early, since all B_TILE
    # lanes share this program instance.
    # NOTE(review): the wrapper picks T_CHUNK = 48, which is not a power of 2;
    # upstream Triton requires power-of-2 tl.arange ranges, so this presumably
    # relies on triton-ascend accepting it -- confirm.

    # program ids
    pid_b = tl.program_id(0)  # batch-tile id
    pid_c = tl.program_id(1)  # channel-tile id

    # channel indices for this program
    idx_feats = pid_c * BLOCK_N + tl.arange(0, BLOCK_N)  # [BLOCK_N]
    mask_w = idx_feats < dim

    # preload weights once per program (shared by B_TILE sequences)
    w_base = w_ptr + idx_feats * stride_w_dim
    # define to avoid "undefined" in branches
    w_col0 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col2 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col3 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col4 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col5 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if KERNEL_WIDTH >= 1:
        w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 2:
        w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 3:
        w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 4:
        w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 5:
        w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 6:
        w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)

    # bias vector once per program
    if HAS_BIAS:
        acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32)
    else:
        acc_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)

    # token index vector for chunked copy
    tok_vec = tl.arange(0, T_CHUNK)  # [T_CHUNK]

    # process B_TILE sequences inside the same program instance
    for bi in tl.static_range(0, B_TILE):
        b = pid_b * B_TILE + bi  # scalar tl.int32
        lane_active = b < batch  # scalar predicate

        # -------------------------
        # APC mapping (optional)
        # -------------------------
        # With APC, the input and output cache lines are selected through
        # separate per-sequence offsets into conv_state_indices; otherwise both
        # use offset 0 (i.e. the same cache line).
        if IS_APC_ENABLED:
            conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32)
            current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32)
        else:
            conv_state_init = tl.full((), 0, tl.int32)
            current_last_index = tl.full((), 0, tl.int32)

        # input cache line
        conv_states_input_coord = tl.load(
            conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0
        ).to(tl.int64)

        if USE_PAD_SLOT:
            lane_active = lane_active & (conv_states_input_coord != pad_slot_id)

        # -------------------------
        # varlen (optional): revise seqlen_run and state_len_run like original kernel does
        # -------------------------
        if IS_VARLEN:
            qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64)
            qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64)
            seqlen_run = (qe - qs).to(tl.int32)
            # revise effective state_len for shorter sequences (same formula as original)
            state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32)
            x_offset = (qs * stride_x_token).to(tl.int64)
            o_offset = (qs * stride_o_token).to(tl.int64)
        else:
            seqlen_run = tl.full((), seqlen, tl.int32)
            state_len_run = tl.full((), state_len, tl.int32)
            x_offset = (b * stride_x_seq).to(tl.int64)
            o_offset = (b * stride_o_seq).to(tl.int64)

        # empty sequence -> skip (avoid early return because other lanes in tile)
        lane_active = lane_active & (seqlen_run > 0)

        # -------------------------
        # spec decoding offset (optional)
        # -------------------------
        # In spec mode the history window starts at (num_accepted_tokens - 1)
        # inside the cached state and the window slides by one token.
        if IS_SPEC_DECODING:
            conv_state_token_offset = tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1
            shift = tl.full((), 1, tl.int32)  # sliding by 1 in spec mode
        else:
            conv_state_token_offset = tl.full((), 0, tl.int64)
            shift = seqlen_run  # normal mode shift by seqlen

        # -------------------------
        # STEP 1: read initial history cols BEFORE state update (out==x safe)
        # -------------------------
        conv_states_base = (
            conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim
        )
        prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok

        # define history vectors as zeros then load conditionally
        # NOTE(review): history columns (and the x values loaded in STEP 3) are
        # cast to tl.float16 regardless of the input dtype; for bf16/fp32 state
        # this can lose range or precision -- confirm this is intended on NPU.
        col0 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col1 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col2 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col3 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col4 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        if KERNEL_WIDTH >= 2:
            col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 3:
            col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 4:
            col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 5:
            col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 6:
            col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )

        # -------------------------
        # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block)
        # Semantics: conv_state <- concat(old_state, x)[-state_len_run:].
        # - If seqlen_run >= state_len_run: dst[:] = x[seqlen_run - state_len_run : seqlen_run]
        # - Else: keep = state_len_run - seqlen_run,
        #         dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run]
        # -------------------------
        # output cache line
        conv_states_offset = tl.load(
            conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0
        ).to(tl.int64)

        use_shift = seqlen_run < state_len_run
        use_tail = seqlen_run >= state_len_run

        zero_i32 = tl.full((), 0, tl.int32)
        keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32)
        tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32)

        # base pointers
        state_src_base = (
            conv_state_ptr
            + conv_states_input_coord * stride_conv_state_seq
            + conv_state_token_offset * stride_conv_state_tok
            + idx_feats * stride_conv_state_dim
        )
        state_dst_base = conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim

        x_base = x_ptr + x_offset + idx_feats * stride_x_dim

        # A) shift old state into dst[0:keep_shift) (only when seqlen_run < state_len_run)
        # Chunks are processed in increasing token order and each chunk loads
        # before it stores, so with shift > 0 the reads stay ahead of the writes
        # even when the source and destination cache lines are the same.
        for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
            dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            src_tok = (dst_tok + shift).to(tl.int32)  # [T_CHUNK]
            m_tok = use_shift & (dst_tok < keep_shift) & (src_tok < state_len_run) & (dst_tok < state_len_run)
            m = (
                (lane_active & m_tok)[:, None]
                & mask_w[None, :]
                & (conv_states_input_coord < num_cache_lines)
                & (conv_states_offset < num_cache_lines)
            )

            src_ptrs = state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            vals = tl.load(src_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, vals, mask=m)

        # B) append x into dst[keep_shift : keep_shift+seqlen_run) (only when seqlen_run < state_len_run)
        # NOTE(review): the static trip count is the constexpr `seqlen`
        # (max_query_len in varlen mode); if a negative value is passed this
        # loop never executes and x is not appended to the state.
        for t0 in tl.static_range(0, seqlen, T_CHUNK):
            x_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            dst_tok = (keep_shift + x_tok).to(tl.int32)  # [T_CHUNK]
            m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run)
            m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)

            x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            x_vals = tl.load(x_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, x_vals, mask=m)

        # C) if seqlen_run >= state_len_run, overwrite dst with the tail of x
        for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
            dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            x_tok = (tail_start + dst_tok).to(tl.int32)  # [T_CHUNK]
            m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run)
            m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)

            x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            x_vals = tl.load(x_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, x_vals, mask=m)

        # -------------------------
        # STEP 3/4/5: causal conv1d (+ optional SiLU) and store output
        # This is original STEP3~5, but per-lane and without debug_barrier.
        # -------------------------
        x_base_1d = x_base
        o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim

        # accumulator preload (bias)
        acc_preload = acc_bias

        # compute each token; keep tl.range so varlen can use seqlen_run as runtime trip count (like original)
        # Output token t is written to the same offset it was read from, after
        # the read, so o_ptr may alias x_ptr (the wrapper passes out = x).
        for idx_token in tl.range(seqlen_run):
            acc = acc_preload

            # same selection logic as original (unrolled by KERNEL_WIDTH)
            # Tap j < KERNEL_WIDTH - 1 uses history column j; the last tap uses
            # the current input token x[idx_token].
            matrix_w = w_col0
            matrix_x = col0
            for j in tl.static_range(KERNEL_WIDTH):
                if KERNEL_WIDTH == 1:
                    # only x[t] * w0
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                    matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                    matrix_w = w_col0
                elif KERNEL_WIDTH == 2:
                    if j == 1:
                        matrix_w = w_col1
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 3:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 4:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 5:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        matrix_x = col3
                    elif j == 4:
                        matrix_w = w_col4
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 6:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        matrix_x = col3
                    elif j == 4:
                        matrix_w = w_col4
                        matrix_x = col4
                    elif j == 5:
                        matrix_w = w_col5
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)

                acc += matrix_x.to(tl.float32) * matrix_w  # [BLOCK_N]

            # roll history window
            # After the tap loop matrix_x holds the current token, which becomes
            # the newest history column for the next token.
            if KERNEL_WIDTH == 2:
                col0 = matrix_x
            elif KERNEL_WIDTH == 3:
                col0 = col1
                col1 = matrix_x
            elif KERNEL_WIDTH == 4:
                col0 = col1
                col1 = col2
                col2 = matrix_x
            elif KERNEL_WIDTH == 5:
                col0 = col1
                col1 = col2
                col2 = col3
                col3 = matrix_x
            elif KERNEL_WIDTH == 6:
                col0 = col1
                col1 = col2
                col2 = col3
                col3 = col4
                col4 = matrix_x

            if SILU_ACTIVATION:
                # SiLU: acc * sigmoid(acc)
                acc = acc / (1.0 + tl.exp(-acc))

            # store output
            o_ptrs = o_base_1d + idx_token * stride_o_token
            tl.store(o_ptrs, acc, mask=lane_active & mask_w)
|
|
|
|
|
|
def causal_conv1d_update_npu(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    activation: bool | str | None = None,
    conv_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    query_start_loc: torch.Tensor | None = None,
    max_query_len: int = -1,
    pad_slot_id: int = PAD_SLOT_ID,
    block_idx_last_scheduled_token: torch.Tensor | None = None,
    initial_state_idx: torch.Tensor | None = None,
    validate_data=False,
):
    """
    Causal conv1d state update (decode / spec-decode step) via a tiled Triton kernel.

    x: Input tensor which can take the following shapes:

    - `[batch, dim]` - single token prediction
    - `[batch, seqlen, dim]` - single or multiple tokens prediction
    - `[num_tokens, dim]` - continuous batching, where num_tokens is
        the total tokens of all sequences in that batch

    conv_state: (num_cache_lines, dim, state_len), where state_len >= width - 1.
        Updated in place. The kernel operates on `conv_state.transpose(1, 2)`;
        when that view is not contiguous a temporary copy is used and the
        updated values are copied back into `conv_state` afterwards.
    weight: (dim, width)
    bias: (dim,)
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario. If None (non-varlen only),
        sequence b uses cache line b.
    block_idx_last_scheduled_token: (batch,), dtype int32
        The pointer into conv_state_indices, where the last cache block to be filled is located.
    initial_state_idx: (batch,), dtype int32
        The pointer into conv_state_indices, where the cache block containing the initial state is located.
    num_accepted_tokens: (batch,), dtype int32
        If not None, it indicates the number of accepted tokens for each
        sequence in the batch.
        This is used in speculative decoding, where the conv_state is updated
        in a sliding window manner.
    query_start_loc: (batch + 1,) int32
        If not None, the inputs is given in a varlen fashion and this indicates
        the starting index of each sequence in the batch.
    max_query_len: int
        If query_start_loc is not None, this indicates the maximum query
        length in the batch. Must be set (>= 0) in varlen mode.
    pad_slot_id: int
        if conv_state_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    validate_data: run extra (assert-based) input checks.

    out: (batch, dim) or (batch, seqlen, dim) or (num_tokens, dim), same shape as `x`,
        in x's original dtype. When `x` already has conv_state's dtype the
        result is written into x's storage (x is overwritten).
    """
    weight = weight.transpose(0, 1).contiguous()  # (width, dim)
    # The kernel addresses a token-major (num_cache_lines, state_len, dim) view.
    # `.contiguous()` is a no-op when the caller's storage already has that
    # layout; otherwise it copies and the update is written back below.
    conv_state_view = conv_state.transpose(1, 2)
    conv_state = conv_state_view.contiguous()
    if validate_data:
        assert pad_slot_id is not None
        # Channels (last dim) must be contiguous for every accepted x layout.
        assert x.stride(-1) == 1
        # In varlen mode the kernel's static loop bound is max_query_len.
        assert query_start_loc is None or max_query_len >= 0
    if isinstance(activation, bool):
        activation = "silu" if activation is True else None
    elif activation is not None:
        assert activation in ["silu", "swish"]

    original_x_dtype = x.dtype
    x = x.to(conv_state.dtype)
    unsqueeze = query_start_loc is None and x.dim() == 2
    if unsqueeze:
        # make it (batch, seqlen, dim) with seqlen == 1
        x = x.unsqueeze(1)

    if query_start_loc is None:
        batch, seqlen, dim = x.shape
        if conv_state_indices is None:
            # The kernel always reads indices; default to the identity mapping
            # (sequence b <-> cache line b) instead of passing a null pointer.
            conv_state_indices = torch.arange(batch, dtype=torch.int32, device=x.device)
    else:
        assert conv_state_indices is not None
        batch = conv_state_indices.size(0)
        dim = x.size(1)
        seqlen = max_query_len

    width, _ = weight.shape
    num_cache_lines, state_len_total, _ = conv_state.size()
    if validate_data:
        assert state_len_total >= width - 1

    # overwrite-on-x strategy same as original
    out = x

    stride_w_width, stride_w_dim = weight.stride()
    if query_start_loc is None:
        stride_x_seq, stride_x_token, stride_x_dim = x.stride()
        stride_o_seq, stride_o_token, stride_o_dim = out.stride()
    else:
        stride_x_token, stride_x_dim = x.stride()
        stride_x_seq = 0
        stride_o_token, stride_o_dim = out.stride()
        stride_o_seq = 0

    stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride()
    stride_state_indices = conv_state_indices.stride(0)

    # effective state_len exactly as original
    if num_accepted_tokens is not None:
        eff_state_len = width - 1 + (seqlen - 1)
    else:
        eff_state_len = width - 1
    np2_statelen = triton.next_power_of_2(eff_state_len)

    # -------- tiling heuristic--------
    # keep program count around ~[80..160]
    # vector core 40
    # TODO: use driver to get the vector core num
    CORE_HINT = 40
    # channel tile: 512 when dim large (reduce tasks), else 256
    block_n = 512 if dim >= 512 else 256
    g = triton.cdiv(dim, block_n)
    target = 2 * CORE_HINT  # ~80
    b_tile_raw = max(1, (batch * g + target - 1) // target)
    # clamp to small set
    if b_tile_raw <= 1:
        b_tile = 1
    elif b_tile_raw <= 2:
        b_tile = 2
    elif b_tile_raw <= 4:
        b_tile = 4
    else:
        b_tile = 8

    # token chunk based on block_n (32KB UB idea); conservative
    t_chunk = 1 if block_n == 512 else 48

    def grid(META):
        return (
            triton.cdiv(batch, META["B_TILE"]),
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    _causal_conv1d_update_kernel_npu_tiled[grid](
        x,
        weight,
        bias,
        conv_state,
        conv_state_indices,
        num_accepted_tokens,
        query_start_loc,
        block_idx_last_scheduled_token,
        initial_state_idx,
        out,
        batch,
        dim,
        seqlen,
        eff_state_len,
        num_cache_lines,
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        pad_slot_id,
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        IS_VARLEN=query_start_loc is not None,
        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=block_n,
        B_TILE=b_tile,
        T_CHUNK=t_chunk,
    )

    if conv_state.data_ptr() != conv_state_view.data_ptr():
        # `.contiguous()` copied: propagate the kernel's update to the caller's tensor.
        conv_state_view.copy_(conv_state)

    if unsqueeze:
        out = out.squeeze(1)
    return out.to(original_x_dtype)
|