### What this PR does / why we need it?
Some parameters of the Triton operators are unnecessarily declared with the `constexpr` qualifier. Whenever the value of such a parameter changes, the kernel is recompiled, which significantly hurts model performance, so the `constexpr` qualifier is removed from those parameters.

main branch: https://github.com/vllm-project/vllm-ascend/pull/7483

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

---------

Signed-off-by: cvSoldier <610496306@qq.com>
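
For illustration only (a minimal sketch, not the operators touched in this PR): a `tl.constexpr` parameter is baked into the compiled kernel, so every new value produces a new compilation, while a plain runtime argument listed in `do_not_specialize` is compiled once.

```python
import triton
import triton.language as tl


# `n` is a compile-time constant here: calling this kernel with n=100,
# then n=101, triggers two separate compilations.
@triton.jit
def _scale_constexpr(x_ptr, n: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * 2.0, mask=mask)


# `n` is an ordinary runtime scalar and is excluded from value-based
# specialization, so the kernel is compiled once and reused as n varies.
@triton.jit(do_not_specialize=["n"])
def _scale_runtime(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * 2.0, mask=mask)
```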
# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0

# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors

from typing import Any

import torch
import torch.nn.functional as F
from vllm.distributed import get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID  # type: ignore


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    initial_states: torch.Tensor | None = None,
    return_final_states: bool = False,
    final_states_out: torch.Tensor | None = None,
    activation: str | None = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape

    if initial_states is None:
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]

    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)
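
# Example usage of causal_conv1d_ref (illustrative sketch only, not exercised by
# vLLM itself): run the reference convolution on a random batch and carry the
# final state forward to the next chunk.
#
#     x = torch.randn(2, 64, 16)      # (batch, dim, seqlen)
#     weight = torch.randn(64, 4)     # (dim, width)
#     out, final_states = causal_conv1d_ref(x, weight, return_final_states=True)
#     # out: (2, 64, 16); final_states: (2, 64, 3), i.e. the last width-1 inputs,
#     # which can be passed back as `initial_states` for the following chunk.

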
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    activation: str | None = "silu",
    conv_states: torch.Tensor | None = None,
    has_initial_state: torch.Tensor | None = None,
    cache_indices: torch.Tensor | None = None,
    query_start_loc: torch.Tensor | None = None,
    metadata: Any | None = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence, prepended by 0.
        for example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the
        initial state for the calculations
    conv_states: (...,dim,width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    forward_context = get_forward_context()
    num_decodes = 0
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is not None and isinstance(attn_metadata, dict):
        attn_metadata = next(iter(attn_metadata.values()), None)
    if attn_metadata is not None:
        num_decodes = attn_metadata.num_decodes

    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    out_ref = []
    out_ref_b = []
    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)
    width = weight.shape[1]
    last_width_prefill_x = extract_last_width(x, query_start_loc[num_decodes:], conv_states.shape[-1])

    if get_pcp_group().world_size > 1:
        all_last_width_prefill_x = get_pcp_group().all_gather(last_width_prefill_x.unsqueeze(0).contiguous(), 0)
        pcp_rank = get_pcp_group().rank_in_group
        if pcp_rank > 0:
            conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[pcp_rank - 1, ...]

    for i in range(len(seqlens)):
        x_s = splits[i]
        if cache_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0),
                initial_states=conv_states[cache_indices[i]][..., : (width - 1)],
            )
        )

    if get_pcp_group().world_size > 1:
        conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[-1, ...]
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    return out_ref_tensor


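# extract_last_width gathers, for each sequence, the last `width` tokens ending
# at start_loc[1:]. Illustrative sketch (values chosen to match the
# query_start_loc example above): with x of shape (dim, 17),
# start_loc = [0, 10, 16, 17] and width = 3, the gathered token indices are
# [7, 8, 9], [13, 14, 15] and [14, 15, 16], and the result has shape (3, dim, 3).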
def extract_last_width(x, start_loc, width):
    end_loc = start_loc[1:]
    offsets = torch.arange(width, device=x.device)
    indices = end_loc.unsqueeze(1) - width + offsets.unsqueeze(0)  # (num_seqs, width)

    return x[:, indices].permute(1, 0, 2)


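# Launch-grid reading aid (descriptive note; see the wrapper below for the exact
# heuristic): axis 0 of the grid walks tiles of B_TILE sequences and axis 1
# walks tiles of BLOCK_N channels, so one program instance handles up to
# B_TILE sequences for one BLOCK_N-wide slice of the feature dimension.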
@triton.jit(
    do_not_specialize=[
        "batch",
        "state_len",
        "num_cache_lines",
        "stride_x_seq",
        "stride_x_token",
        "stride_conv_state_seq",
        "stride_conv_state_tok",
        "stride_state_indices",
        "stride_o_seq",
        "stride_o_token",
    ]
)
def _causal_conv1d_update_kernel_npu_tiled(
    # Pointers
    x_ptr,  # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
    w_ptr,  # (dim, width)
    bias_ptr,
    conv_state_ptr,  # (num_cache_lines, dim, state_len)
    conv_state_indices_ptr,
    num_accepted_tokens_ptr,
    query_start_loc_ptr,  # (batch + 1)
    block_idx_last_scheduled_token,  # (batch,)
    initial_state_idx,  # (batch,)
    o_ptr,  # same shape as x_ptr
    batch: tl.int32,
    dim: tl.constexpr,
    seqlen: tl.constexpr,  # max seqlen for varlen, or exact seqlen
    state_len,  # effective state_len computed in wrapper
    num_cache_lines,
    # Strides
    stride_x_seq,
    stride_x_dim: tl.constexpr,
    stride_x_token,
    stride_w_dim: tl.constexpr,
    stride_w_width: tl.constexpr,
    stride_conv_state_seq,
    stride_conv_state_dim: tl.constexpr,
    stride_conv_state_tok,
    stride_state_indices,
    stride_o_seq,
    stride_o_dim: tl.constexpr,
    stride_o_token,
    # others
    pad_slot_id: tl.constexpr,
    # Meta
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,  # <= 6
    SILU_ACTIVATION: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_APC_ENABLED: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    NP2_STATELEN: tl.constexpr,
    USE_PAD_SLOT: tl.constexpr,
    # tiling
    BLOCK_N: tl.constexpr,  # channel tile (C_TILE)
    B_TILE: tl.constexpr,  # batch tile
    T_CHUNK: tl.constexpr,  # token chunk for state update
):
    # program ids
    pid_b = tl.program_id(0)  # batch-tile id
    pid_c = tl.program_id(1)  # channel-tile id

    # channel indices for this program
    idx_feats = pid_c * BLOCK_N + tl.arange(0, BLOCK_N)  # [BLOCK_N]
    mask_w = idx_feats < dim

    # preload weights once per program (shared by B_TILE sequences)
    w_base = w_ptr + idx_feats * stride_w_dim
    # define to avoid "undefined" in branches
    w_col0 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col2 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col3 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col4 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    w_col5 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if KERNEL_WIDTH >= 1:
        w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 2:
        w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 3:
        w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 4:
        w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 5:
        w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
    if KERNEL_WIDTH >= 6:
        w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)

    # bias vector once per program
    if HAS_BIAS:
        acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32)
    else:
        acc_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)

    # token index vector for chunked copy
    tok_vec = tl.arange(0, T_CHUNK)  # [T_CHUNK]

    # process B_TILE sequences inside the same program instance
    for bi in tl.static_range(0, B_TILE):
        b = pid_b * B_TILE + bi  # scalar tl.int32
        lane_active = b < batch  # scalar predicate

        # -------------------------
        # APC mapping (optional)
        # -------------------------
        if IS_APC_ENABLED:
            conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32)
            current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32)
        else:
            conv_state_init = tl.full((), 0, tl.int32)
            current_last_index = tl.full((), 0, tl.int32)

        # input cache line
        conv_states_input_coord = tl.load(
            conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0
        ).to(tl.int64)

        if USE_PAD_SLOT:
            lane_active = lane_active & (conv_states_input_coord != pad_slot_id)

        # -------------------------
        # varlen (optional): revise seqlen_run and state_len_run like original kernel does
        # -------------------------
        if IS_VARLEN:
            qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64)
            qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64)
            seqlen_run = (qe - qs).to(tl.int32)
            # revise effective state_len for shorter sequences (same formula as original)
            state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32)
            x_offset = (qs * stride_x_token).to(tl.int64)
            o_offset = (qs * stride_o_token).to(tl.int64)
        else:
            seqlen_run = tl.full((), seqlen, tl.int32)
            state_len_run = tl.full((), state_len, tl.int32)
            x_offset = (b * stride_x_seq).to(tl.int64)
            o_offset = (b * stride_o_seq).to(tl.int64)

        # empty sequence -> skip (avoid early return because other lanes in tile)
        lane_active = lane_active & (seqlen_run > 0)

        # -------------------------
        # spec decoding offset (optional)
        # -------------------------
        if IS_SPEC_DECODING:
            conv_state_token_offset = tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1
            shift = tl.full((), 1, tl.int32)  # sliding by 1 in spec mode
        else:
            conv_state_token_offset = tl.full((), 0, tl.int64)
            shift = seqlen_run  # normal mode shift by seqlen

        # -------------------------
        # STEP 1: read initial history cols BEFORE state update (out==x safe)
        # -------------------------
        conv_states_base = (
            conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim
        )
        prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok

        # define history vectors as zeros then load conditionally
        col0 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col1 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col2 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col3 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        col4 = tl.zeros((BLOCK_N,), dtype=tl.float16)
        if KERNEL_WIDTH >= 2:
            col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 3:
            col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 4:
            col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 5:
            col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )
        if KERNEL_WIDTH >= 6:
            col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
                tl.float16
            )

        # -------------------------
        # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block)
        # Semantics: conv_state <- concat(old_state, x)[-state_len_run:].
        #   - If seqlen_run >= state_len_run: dst[:] = x[seqlen_run - state_len_run : seqlen_run]
        #   - Else: keep = state_len_run - seqlen_run,
        #     dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run]
        # -------------------------
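        # Worked example of the semantics above (illustration only): with
        # state_len_run = 3 and seqlen_run = 2 (so shift = 2 in normal mode),
        # keep = 1 and steps A/B below write dst = [old[2], x[0], x[1]];
        # with seqlen_run = 5 instead, tail_start = 2 and step C writes
        # dst = [x[2], x[3], x[4]].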
        # output cache line
        conv_states_offset = tl.load(
            conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0
        ).to(tl.int64)

        use_shift = seqlen_run < state_len_run
        use_tail = seqlen_run >= state_len_run

        zero_i32 = tl.full((), 0, tl.int32)
        keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32)
        tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32)

        # base pointers
        state_src_base = (
            conv_state_ptr
            + conv_states_input_coord * stride_conv_state_seq
            + conv_state_token_offset * stride_conv_state_tok
            + idx_feats * stride_conv_state_dim
        )
        state_dst_base = conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim

        x_base = x_ptr + x_offset + idx_feats * stride_x_dim

        # A) shift old state into dst[0:keep_shift) (only when seqlen_run < state_len_run)
        for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
            dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            src_tok = (dst_tok + shift).to(tl.int32)  # [T_CHUNK]
            m_tok = use_shift & (dst_tok < keep_shift) & (src_tok < state_len_run) & (dst_tok < state_len_run)
            m = (
                (lane_active & m_tok)[:, None]
                & mask_w[None, :]
                & (conv_states_input_coord < num_cache_lines)
                & (conv_states_offset < num_cache_lines)
            )

            src_ptrs = state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            vals = tl.load(src_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, vals, mask=m)

        # B) append x into dst[keep_shift : keep_shift+seqlen_run) (only when seqlen_run < state_len_run)
        for t0 in tl.static_range(0, seqlen, T_CHUNK):
            x_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            dst_tok = (keep_shift + x_tok).to(tl.int32)  # [T_CHUNK]
            m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run)
            m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)

            x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            x_vals = tl.load(x_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, x_vals, mask=m)

        # C) if seqlen_run >= state_len_run, overwrite dst with the tail of x
        for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
            dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
            x_tok = (tail_start + dst_tok).to(tl.int32)  # [T_CHUNK]
            m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run)
            m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)

            x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
            dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
            x_vals = tl.load(x_ptrs, mask=m, other=0.0)
            tl.store(dst_ptrs, x_vals, mask=m)

        # -------------------------
        # STEP 3/4/5: causal conv1d (+ optional SiLU) and store output
        # This is original STEP3~5, but per-lane and without debug_barrier.
        # -------------------------
        x_base_1d = x_base
        o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim

        # accumulator preload (bias)
        acc_preload = acc_bias

        # compute each token; keep tl.range so varlen can use seqlen_run as runtime trip count (like original)
        for idx_token in tl.range(seqlen_run):
            acc = acc_preload

            # same selection logic as original (unrolled by KERNEL_WIDTH)
            matrix_w = w_col0
            matrix_x = col0
            for j in tl.static_range(KERNEL_WIDTH):
                if KERNEL_WIDTH == 1:
                    # only x[t] * w0
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                    matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                    matrix_w = w_col0
                elif KERNEL_WIDTH == 2:
                    if j == 1:
                        matrix_w = w_col1
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 3:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 4:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 5:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        matrix_x = col3
                    elif j == 4:
                        matrix_w = w_col4
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
                elif KERNEL_WIDTH == 6:
                    if j == 1:
                        matrix_w = w_col1
                        matrix_x = col1
                    elif j == 2:
                        matrix_w = w_col2
                        matrix_x = col2
                    elif j == 3:
                        matrix_w = w_col3
                        matrix_x = col3
                    elif j == 4:
                        matrix_w = w_col4
                        matrix_x = col4
                    elif j == 5:
                        matrix_w = w_col5
                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token
                        matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)

                acc += matrix_x.to(tl.float32) * matrix_w  # [BLOCK_N]

            # roll history window
            if KERNEL_WIDTH == 2:
                col0 = matrix_x
            elif KERNEL_WIDTH == 3:
                col0 = col1
                col1 = matrix_x
            elif KERNEL_WIDTH == 4:
                col0 = col1
                col1 = col2
                col2 = matrix_x
            elif KERNEL_WIDTH == 5:
                col0 = col1
                col1 = col2
                col2 = col3
                col3 = matrix_x
            elif KERNEL_WIDTH == 6:
                col0 = col1
                col1 = col2
                col2 = col3
                col3 = col4
                col4 = matrix_x

            if SILU_ACTIVATION:
                acc = acc / (1.0 + tl.exp(-acc))

            # store output
            o_ptrs = o_base_1d + idx_token * stride_o_token
            tl.store(o_ptrs, acc, mask=lane_active & mask_w)


def causal_conv1d_update_npu(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    activation: bool | str | None = None,
    conv_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    query_start_loc: torch.Tensor | None = None,
    max_query_len: int = -1,
    pad_slot_id: int = PAD_SLOT_ID,
    block_idx_last_scheduled_token: torch.Tensor | None = None,
    initial_state_idx: torch.Tensor | None = None,
    validate_data=False,
):
    """
    x: Input tensor which can take the following shapes:

    - `[batch, dim]` - single token prediction
    - `[batch, dim, seqlen]` - single or multiple tokens prediction
    - `[num_tokens, dim]` - continuous batching, where num_tokens is
      the total tokens of all sequences in that batch

    conv_state: (..., dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.
    block_idx_last_scheduled_token: (batch,), dtype int32
        The pointer into conv_state_indices, where the last cache block to be filled is located.
    initial_state_idx: (batch,), dtype int32
        The pointer into conv_state_indices, where the cache block containing the initial state is located.
    num_accepted_tokens: (batch,), dtype int32
        If not None, it indicates the number of accepted tokens for each
        sequence in the batch.
        This is used in speculative decoding, where the conv_state is updated
        in a sliding window manner.
    query_start_loc: (batch + 1,) int32
        If not None, the input is given in a varlen fashion and this indicates
        the starting index of each sequence in the batch.
    max_query_len: int
        If query_start_loc is not None, this indicates the maximum query
        length in the batch.
    pad_slot_id: int
        if conv_state_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
    """
    weight = weight.transpose(0, 1).contiguous()
    conv_state = conv_state.transpose(1, 2).contiguous()
    if validate_data:
        assert pad_slot_id is not None
        assert x.stride(1) == 1
    if isinstance(activation, bool):
        activation = "silu" if activation is True else None
    elif activation is not None:
        assert activation in ["silu", "swish"]

    original_x_dtype = x.dtype
    x = x.to(conv_state.dtype)
    unsqueeze = query_start_loc is None and x.dim() == 2
    if unsqueeze:
        # make it (batch, seqlen, dim) with seqlen == 1
        x = x.unsqueeze(1)

    if query_start_loc is None:
        batch, seqlen, dim = x.shape
    else:
        assert conv_state_indices is not None
        batch = conv_state_indices.size(0)
        dim = x.size(1)
        seqlen = max_query_len

    width, _ = weight.shape
    num_cache_lines, state_len_total, _ = conv_state.size()

    # overwrite-on-x strategy same as original
    out = x

    stride_w_width, stride_w_dim = weight.stride()
    if query_start_loc is None:
        stride_x_seq, stride_x_token, stride_x_dim = x.stride()
        stride_o_seq, stride_o_token, stride_o_dim = out.stride()
    else:
        stride_x_token, stride_x_dim = x.stride()
        stride_x_seq = 0
        stride_o_token, stride_o_dim = out.stride()
        stride_o_seq = 0

    stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride()
    stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0

    # effective state_len exactly as original
    if num_accepted_tokens is not None:
        eff_state_len = width - 1 + (seqlen - 1)
    else:
        eff_state_len = width - 1
    np2_statelen = triton.next_power_of_2(eff_state_len)

    # -------- tiling heuristic --------
    # keep program count around ~[80..160]
    # vector core 40
    # TODO: use driver to get the vector core num
    CORE_HINT = 40
    # channel tile: 512 when dim large (reduce tasks), else 256
    block_n = 512 if dim >= 512 else 256
    g = triton.cdiv(dim, block_n)
    target = 2 * CORE_HINT  # ~80
    b_tile_raw = max(1, (batch * g + target - 1) // target)
    # clamp to small set
    if b_tile_raw <= 1:
        b_tile = 1
    elif b_tile_raw <= 2:
        b_tile = 2
    elif b_tile_raw <= 4:
        b_tile = 4
    else:
        b_tile = 8

    # token chunk based on block_n (32KB UB idea); conservative
    t_chunk = 1 if block_n == 512 else 48

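    # Worked example of the heuristic (illustration only): batch = 20, dim = 2048
    # gives block_n = 512, g = 4, b_tile_raw = (80 + 79) // 80 = 1 -> b_tile = 1,
    # t_chunk = 1, and a grid of (20, 4) = 80 programs, i.e. two waves on the
    # assumed 40 vector cores.
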
    def grid(META):
        return (
            triton.cdiv(batch, META["B_TILE"]),
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    _causal_conv1d_update_kernel_npu_tiled[grid](
        x,
        weight,
        bias,
        conv_state,
        conv_state_indices,
        num_accepted_tokens,
        query_start_loc,
        block_idx_last_scheduled_token,
        initial_state_idx,
        out,
        batch,
        dim,
        seqlen,
        eff_state_len,
        num_cache_lines,
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        pad_slot_id,
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        IS_VARLEN=query_start_loc is not None,
        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=block_n,
        B_TILE=b_tile,
        T_CHUNK=t_chunk,
    )

    if unsqueeze:
        out = out.squeeze(1)
    return out.to(original_x_dtype)
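

# Illustrative decode-time call (a sketch only; in vLLM this path is driven by
# the Mamba attention backend rather than called directly, and it assumes
# torch_npu is initialized so that .npu() moves tensors to the NPU):
#
#     batch, dim, width, num_cache_lines = 4, 256, 4, 8
#     x = torch.randn(batch, dim).npu()                            # one new token per sequence
#     conv_state = torch.zeros(num_cache_lines, dim, width - 1).npu()
#     weight = torch.randn(dim, width).npu()
#     slots = torch.tensor([0, 1, 2, 3], dtype=torch.int32).npu()
#     out = causal_conv1d_update_npu(x, conv_state, weight, activation="silu",
#                                    conv_state_indices=slots)
#     # out has shape (batch, dim); conv_state[slots] now holds the rolled history.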