# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2024, Tri Dao. # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py # mypy: ignore-errors from typing import Any, Optional import torch import torch.nn.functional as F import triton import triton.language as tl from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore def causal_conv1d_ref( x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, initial_states: Optional[torch.Tensor] = None, return_final_states: bool = False, final_states_out: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", ): """ x: (batch, dim, seqlen) weight: (dim, width) bias: (dim,) initial_states: (batch, dim, width - 1) final_states_out: (batch, dim, width - 1) out: (batch, dim, seqlen) """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") dtype_in = x.dtype x = x.to(weight.dtype) seqlen = x.shape[-1] dim, width = weight.shape if initial_states is None: out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) else: x = torch.cat([initial_states, x], dim=-1) out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) out = out[..., :seqlen] if return_final_states: final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( dtype_in) # (batch, dim, width - 1) if final_states_out is not None: final_states_out.copy_(final_states) else: final_states_out = final_states out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) return (out, None) if not return_final_states else (out, final_states_out) def causal_conv1d_fn( x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", conv_states: Optional[torch.Tensor] = None, has_initial_state: Optional[torch.Tensor] = None, cache_indices: Optional[torch.Tensor] = None, query_start_loc: Optional[torch.Tensor] = None, metadata: Optional[Any] = None, pad_slot_id: int = PAD_SLOT_ID, ): """ x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen sequences are concatenated from left to right for varlen weight: (dim, width) bias: (dim,) query_start_loc: (batch + 1) int32 The cumulative sequence lengths of the sequences in the batch, used to index into sequence. prepended by 0. for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 indicates the corresponding state index, like so: conv_state = conv_states[cache_indices[batch_id]] has_initial_state: (batch) bool indicates whether should the kernel take the current state as initial state for the calculations conv_states: (...,dim,width - 1) itype updated inplace if provided activation: either None or "silu" or "swish" pad_slot_id: int if cache_indices is passed, lets the kernel identify padded entries that will not be processed, for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim, seqlen) """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") if x.stride(-1) != 1: x = x.contiguous() bias = bias.contiguous() if bias is not None else None out_ref = [] out_ref_b = [] seqlens = query_start_loc[1:] - query_start_loc[:-1] seqlens = seqlens.tolist() splits = torch.split(x, seqlens, dim=-1) width = weight.shape[1] for i in range(len(seqlens)): x_s = splits[i] if cache_indices[i] == PAD_SLOT_ID: continue out_ref_b.append( causal_conv1d_ref( x_s, weight, bias, activation=activation, return_final_states=True, final_states_out=conv_states[cache_indices[i]][..., :( width - 1)].unsqueeze(0), initial_states=conv_states[cache_indices[i]][..., :(width - 1)] if has_initial_state[i] else None)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) out_ref_tensor = torch.cat(out_ref, dim=0) return out_ref_tensor @triton.jit def _causal_conv1d_update_kernel_npu_tiled( # Pointers x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen w_ptr, # (dim, width) bias_ptr, conv_state_ptr, # (num_cache_lines, dim, state_len) conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) block_idx_last_scheduled_token, # (batch,) initial_state_idx, # (batch,) o_ptr, # same shape as x_ptr batch: tl.int32, dim: tl.constexpr, seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen state_len: tl.constexpr, # effective state_len computed in wrapper num_cache_lines: tl.constexpr, # Strides stride_x_seq: tl.constexpr, stride_x_dim: tl.constexpr, stride_x_token: tl.constexpr, stride_w_dim: tl.constexpr, stride_w_width: tl.constexpr, stride_conv_state_seq: tl.constexpr, stride_conv_state_dim: tl.constexpr, stride_conv_state_tok: tl.constexpr, stride_state_indices: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, # others pad_slot_id: tl.constexpr, # Meta HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, # <= 6 SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, IS_APC_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, # tiling BLOCK_N: tl.constexpr, # channel tile (C_TILE) B_TILE: tl.constexpr, # batch tile T_CHUNK: tl.constexpr, # token chunk for state update ): # program ids pid_b = tl.program_id(0) # batch-tile id pid_c = tl.program_id(1) # channel-tile id # channel indices for this program idx_feats = pid_c * BLOCK_N + tl.arange(0, BLOCK_N) # [BLOCK_N] mask_w = idx_feats < dim # preload weights once per program (shared by B_TILE sequences) w_base = w_ptr + idx_feats * stride_w_dim # define to avoid "undefined" in branches w_col0 = tl.zeros((BLOCK_N, ), dtype=tl.float32) w_col1 = tl.zeros((BLOCK_N, ), dtype=tl.float32) w_col2 = tl.zeros((BLOCK_N, ), dtype=tl.float32) w_col3 = tl.zeros((BLOCK_N, ), dtype=tl.float32) w_col4 = tl.zeros((BLOCK_N, ), dtype=tl.float32) w_col5 = tl.zeros((BLOCK_N, ), dtype=tl.float32) if KERNEL_WIDTH >= 1: w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 2: w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 3: w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 4: w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 5: w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 6: w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) # bias vector once per program if HAS_BIAS: acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32) else: acc_bias = tl.zeros((BLOCK_N, ), dtype=tl.float32) # token index vector for chunked copy tok_vec = tl.arange(0, T_CHUNK) # [T_CHUNK] # process B_TILE sequences inside the same program instance for bi in tl.static_range(0, B_TILE): b = pid_b * B_TILE + bi # scalar tl.int32 lane_active = b < batch # scalar predicate # ------------------------- # APC mapping (optional) # ------------------------- if IS_APC_ENABLED: conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32) current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32) else: conv_state_init = tl.full((), 0, tl.int32) current_last_index = tl.full((), 0, tl.int32) # input cache line conv_states_input_coord = tl.load(conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0).to(tl.int64) if USE_PAD_SLOT: lane_active = lane_active & (conv_states_input_coord != pad_slot_id) # ------------------------- # varlen (optional): revise seqlen_run and state_len_run like original kernel does # ------------------------- if IS_VARLEN: qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64) qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64) seqlen_run = (qe - qs).to(tl.int32) # revise effective state_len for shorter sequences (same formula as original) state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32) x_offset = (qs * stride_x_token).to(tl.int64) o_offset = (qs * stride_o_token).to(tl.int64) else: seqlen_run = tl.full((), seqlen, tl.int32) state_len_run = tl.full((), state_len, tl.int32) x_offset = (b * stride_x_seq).to(tl.int64) o_offset = (b * stride_o_seq).to(tl.int64) # empty sequence -> skip (avoid early return because other lanes in tile) lane_active = lane_active & (seqlen_run > 0) # ------------------------- # spec decoding offset (optional) # ------------------------- if IS_SPEC_DECODING: conv_state_token_offset = ( tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1) shift = tl.full((), 1, tl.int32) # sliding by 1 in spec mode else: conv_state_token_offset = tl.full((), 0, tl.int64) shift = seqlen_run # normal mode shift by seqlen # ------------------------- # STEP 1: read initial history cols BEFORE state update (out==x safe) # ------------------------- conv_states_base = (conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim) prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok # define history vectors as zeros then load conditionally col0 = tl.zeros((BLOCK_N, ), dtype=tl.float16) col1 = tl.zeros((BLOCK_N, ), dtype=tl.float16) col2 = tl.zeros((BLOCK_N, ), dtype=tl.float16) col3 = tl.zeros((BLOCK_N, ), dtype=tl.float16) col4 = tl.zeros((BLOCK_N, ), dtype=tl.float16) if KERNEL_WIDTH >= 2: col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(tl.float16) if KERNEL_WIDTH >= 3: col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(tl.float16) if KERNEL_WIDTH >= 4: col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(tl.float16) if KERNEL_WIDTH >= 5: col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(tl.float16) if KERNEL_WIDTH >= 6: col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(tl.float16) # ------------------------- # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block) # Semantics: conv_state <- concat(old_state, x)[-state_len_run:]. # - If seqlen_run >= state_len_run: dst[:] = x[seqlen_run - state_len_run : seqlen_run] # - Else: keep = state_len_run - seqlen_run, # dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run] # ------------------------- # output cache line conv_states_offset = tl.load(conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0).to(tl.int64) use_shift = (seqlen_run < state_len_run) use_tail = (seqlen_run >= state_len_run) zero_i32 = tl.full((), 0, tl.int32) keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32) tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32) # base pointers state_src_base = (conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + conv_state_token_offset * stride_conv_state_tok + idx_feats * stride_conv_state_dim) state_dst_base = (conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim) x_base = x_ptr + x_offset + idx_feats * stride_x_dim # A) shift old state into dst[0:keep_shift) (only when seqlen_run < state_len_run) for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] src_tok = (dst_tok + shift).to(tl.int32) # [T_CHUNK] m_tok = use_shift & (dst_tok < keep_shift) & ( src_tok < state_len_run) & (dst_tok < state_len_run) m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( conv_states_input_coord < num_cache_lines) & (conv_states_offset < num_cache_lines) src_ptrs = state_src_base[ None, :] + src_tok[:, None] * stride_conv_state_tok dst_ptrs = state_dst_base[ None, :] + dst_tok[:, None] * stride_conv_state_tok vals = tl.load(src_ptrs, mask=m, other=0.0) tl.store(dst_ptrs, vals, mask=m) # B) append x into dst[keep_shift : keep_shift+seqlen_run) (only when seqlen_run < state_len_run) for t0 in tl.static_range(0, seqlen, T_CHUNK): x_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] dst_tok = (keep_shift + x_tok).to(tl.int32) # [T_CHUNK] m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run) m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( conv_states_offset < num_cache_lines) x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token dst_ptrs = state_dst_base[ None, :] + dst_tok[:, None] * stride_conv_state_tok x_vals = tl.load(x_ptrs, mask=m, other=0.0) tl.store(dst_ptrs, x_vals, mask=m) # C) if seqlen_run >= state_len_run, overwrite dst with the tail of x for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] x_tok = (tail_start + dst_tok).to(tl.int32) # [T_CHUNK] m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( conv_states_offset < num_cache_lines) x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token dst_ptrs = state_dst_base[ None, :] + dst_tok[:, None] * stride_conv_state_tok x_vals = tl.load(x_ptrs, mask=m, other=0.0) tl.store(dst_ptrs, x_vals, mask=m) # ------------------------- # STEP 3/4/5: causal conv1d (+ optional SiLU) and store output # This is original STEP3~5, but per-lane and without debug_barrier. # ------------------------- x_base_1d = x_base o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim # accumulator preload (bias) acc_preload = acc_bias # compute each token; keep tl.range so varlen can use seqlen_run as runtime trip count (like original) for idx_token in tl.range(seqlen_run): acc = acc_preload # same selection logic as original (unrolled by KERNEL_WIDTH) matrix_w = w_col0 matrix_x = col0 for j in tl.static_range(KERNEL_WIDTH): if KERNEL_WIDTH == 1: # only x[t] * w0 x_ptrs_1d = x_base_1d + idx_token * stride_x_token matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) matrix_w = w_col0 elif KERNEL_WIDTH == 2: if j == 1: matrix_w = w_col1 x_ptrs_1d = x_base_1d + idx_token * stride_x_token matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 3: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 x_ptrs_1d = x_base_1d + idx_token * stride_x_token matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 4: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 matrix_x = col2 elif j == 3: matrix_w = w_col3 x_ptrs_1d = x_base_1d + idx_token * stride_x_token matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 5: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 matrix_x = col2 elif j == 3: matrix_w = w_col3 matrix_x = col3 elif j == 4: matrix_w = w_col4 x_ptrs_1d = x_base_1d + idx_token * stride_x_token matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 6: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 matrix_x = col2 elif j == 3: matrix_w = w_col3 matrix_x = col3 elif j == 4: matrix_w = w_col4 matrix_x = col4 elif j == 5: matrix_w = w_col5 x_ptrs_1d = x_base_1d + idx_token * stride_x_token matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) acc += matrix_x.to(tl.float32) * matrix_w # [BLOCK_N] # roll history window if KERNEL_WIDTH == 2: col0 = matrix_x elif KERNEL_WIDTH == 3: col0 = col1 col1 = matrix_x elif KERNEL_WIDTH == 4: col0 = col1 col1 = col2 col2 = matrix_x elif KERNEL_WIDTH == 5: col0 = col1 col1 = col2 col2 = col3 col3 = matrix_x elif KERNEL_WIDTH == 6: col0 = col1 col1 = col2 col2 = col3 col3 = col4 col4 = matrix_x if SILU_ACTIVATION: acc = acc / (1.0 + tl.exp(-acc)) # store output o_ptrs = o_base_1d + idx_token * stride_o_token tl.store(o_ptrs, acc, mask=lane_active & mask_w) def causal_conv1d_update_npu( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None, activation: bool | str | None = None, conv_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, query_start_loc: torch.Tensor | None = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, validate_data=False, ): """ x: Input tensor which can take the following shapes: - `[batch, dim]` - single token prediction - `[batch, dim, seqlen]` - single or multiple tokens prediction - `[num_tokens, dim]` - continuous batching, where num_tokens is the total tokens of all sequences in that batch conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) conv_state_indices: (batch,), dtype int32 If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. block_idx_last_scheduled_token: (batch,), dtype int32 The pointer into conv_state_indices, where the last cache block to be filled is located. initial_state_idx: (batch,), dtype int32 The pointer into conv_state_indices, where the cache block containing the initial state is located. num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each sequence in the batch. This is used in speculative decoding, where the conv_state is updated in a sliding window manner. query_start_loc: (batch + 1,) int32 If not None, the inputs is given in a varlen fashion and this indicates the starting index of each sequence in the batch. max_query_len: int If query_start_loc is not None, this indicates the maximum query length in the batch. pad_slot_id: int if conv_state_indices is passed, lets the kernel identify padded entries that will not be processed, for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` """ weight = weight.transpose(0, 1).contiguous() conv_state = conv_state.transpose(1, 2).contiguous() if validate_data: assert pad_slot_id is not None assert x.stride(1) == 1 if isinstance(activation, bool): activation = "silu" if activation is True else None elif activation is not None: assert activation in ["silu", "swish"] original_x_dtype = x.dtype x = x.to(conv_state.dtype) unsqueeze = query_start_loc is None and x.dim() == 2 if unsqueeze: # make it (batch, dim, seqlen) with seqlen == 1 x = x.unsqueeze(1) if query_start_loc is None: batch, seqlen, dim = x.shape else: assert conv_state_indices is not None batch = conv_state_indices.size(0) dim = x.size(1) seqlen = max_query_len width, _ = weight.shape num_cache_lines, state_len_total,_ = conv_state.size() # overwrite-on-x strategy same as original out = x stride_w_width, stride_w_dim = weight.stride() if query_start_loc is None: stride_x_seq, stride_x_token,stride_x_dim = x.stride() stride_o_seq, stride_o_token, stride_o_dim = out.stride() else: stride_x_token, stride_x_dim = x.stride() stride_x_seq = 0 stride_o_token, stride_o_dim = out.stride() stride_o_seq = 0 stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride( ) stride_state_indices = conv_state_indices.stride( 0) if conv_state_indices is not None else 0 # effective state_len exactly as original if num_accepted_tokens is not None: eff_state_len = width - 1 + (seqlen - 1) else: eff_state_len = width - 1 np2_statelen = triton.next_power_of_2(eff_state_len) # -------- tiling heuristic-------- #keep program count around ~[80..160] # vector core 40 # TODO: use driver to get the vector core num CORE_HINT = 40 # channel tile: 512 when dim large (reduce tasks), else 256 block_n = 512 if dim >= 512 else 256 g = triton.cdiv(dim, block_n) target = 2 * CORE_HINT # ~80 b_tile_raw = max(1, (batch * g + target - 1) // target) # clamp to small set if b_tile_raw <= 1: b_tile = 1 elif b_tile_raw <= 2: b_tile = 2 elif b_tile_raw <= 4: b_tile = 4 else: b_tile = 8 # token chunk based on block_n (32KB UB idea); conservative t_chunk = 1 if block_n == 512 else 48 def grid(META): return ( triton.cdiv(batch, META["B_TILE"]), triton.cdiv(dim, META["BLOCK_N"]), ) _causal_conv1d_update_kernel_npu_tiled[grid]( x, weight, bias, conv_state, conv_state_indices, num_accepted_tokens, query_start_loc, block_idx_last_scheduled_token, initial_state_idx, out, batch, dim, seqlen, eff_state_len, num_cache_lines, stride_x_seq, stride_x_dim, stride_x_token, stride_w_dim, stride_w_width, stride_istate_seq, stride_istate_dim, stride_istate_token, stride_state_indices, stride_o_seq, stride_o_dim, stride_o_token, pad_slot_id, HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_VARLEN=query_start_loc is not None, IS_APC_ENABLED=block_idx_last_scheduled_token is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, BLOCK_N=block_n, B_TILE=b_tile, T_CHUNK=t_chunk, ) if unsqueeze: out = out.squeeze(1) return out.to(original_x_dtype)