# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2024, Tri Dao. # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py import numpy as np import torch from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.triton_utils import tl, triton @triton.jit() def _causal_conv1d_fwd_kernel( # continuous batching # Pointers to matrices x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences w_ptr, # (dim, width) bias_ptr, initial_states_ptr, # conv_states_ptr cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains # the block indices relevant for each sequence # plus potential 0-padding at the beginning and at the end has_initial_states_ptr, query_start_loc_ptr, batch_ptr, token_chunk_offset_ptr, block_idx_first_scheduled_token, # (batch,) block_idx_last_scheduled_token, # (batch,) initial_state_idx, # (batch,) num_computed_tokens, # (batch,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions dim: tl.constexpr, seqlen: tl.int32, # cu_seqlen num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides stride_x_dim: tl.constexpr, # stride to get to next feature-value, stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) stride_w_dim: tl.constexpr, # stride to get to next dim-axis value stride_w_width: tl.constexpr, # stride to get to next width-axis value stride_istate_seq: tl.constexpr, stride_istate_dim: tl.constexpr, stride_istate_token: tl.constexpr, stride_cache_indices: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_APC_ENABLED: tl.constexpr, USE_PAD_SLOT: 
tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): conv_states_ptr = initial_states_ptr conv_state_indices_ptr = cache_indices_ptr stride_conv_state_seq = stride_istate_seq stride_conv_state_dim = stride_istate_dim stride_conv_state_tok = stride_istate_token state_len = ( KERNEL_WIDTH - 1 ) # can be passed via argument if it's not the same as this value # one program handles one chunk in a single sequence # rather than mixing sequences - to make updating initial_states across sequences efficiently # single-sequence id idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64) chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) # BLOCK_N elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) if idx_seq == pad_slot_id: return sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) # find the actual sequence length seqlen = sequence_end_index - sequence_start_index B_size: tl.constexpr = stride_block_m * BLOCK_M if IS_APC_ENABLED: # Handle the case if prefix caching is enabled. # In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr" # Get the length of the completed sequence so far and compute the offset. 
current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq) current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) sequence_completed_index = tl.load(num_computed_tokens + idx_seq) # Compute the offset where the first stride_block_m-aligned first full block is # Value in "token-space" sequence_completed_offset_token = sequence_completed_index % B_size seq_completed_offset = B_size - sequence_completed_offset_token seq_end_offset = (seqlen - seq_completed_offset) % B_size last_full_block_token_index = sequence_end_index - seq_end_offset # If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one if seq_end_offset == 0: last_full_block_token_index = last_full_block_token_index - B_size # Get the number of blocks to be filled for the current sequence # If n_block_to_fill = 0, then only the state at the sequence end is stored n_block_to_fill = current_last_index - current_first_index # Get the index of the init block conv_state_init_index = tl.load(initial_state_idx + idx_seq) else: n_block_to_fill = 0 current_last_index = 0 conv_state_init_index = 0 current_first_index = 0 last_full_block_token_index = 0 token_offset = BLOCK_M * chunk_offset segment_len = min(BLOCK_M, seqlen - token_offset) # base of the sequence x_base = ( x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim ) # [BLOCK_N,] # cache_idx conv_states_input_coord = tl.load( conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index ).to(tl.int64) if USE_PAD_SLOT: # noqa if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return conv_states_base = ( conv_states_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) ) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] # Does 2 things: # 1. READ prior-block init-state data - [done by every Triton programs] # 2. 
update conv_state with new data [only by the Triton program handles chunk_offset=0] if chunk_offset == 0: # read from conv_states load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) if load_init_state: # load from conv_states prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) if KERNEL_WIDTH == 3: conv_states_ptrs = prior_tokens # [BLOCK_N] col1 = tl.load(conv_states_ptrs, mask_w, 0.0) conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) if KERNEL_WIDTH == 4: conv_states_ptrs = prior_tokens # [BLOCK_N] col2 = tl.load(conv_states_ptrs, mask_w, 0.0) conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] col1 = tl.load(conv_states_ptrs, mask_w, 0.0) conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) if KERNEL_WIDTH == 5: conv_states_ptrs = prior_tokens # [BLOCK_N] col3 = tl.load(conv_states_ptrs, mask_w, 0.0) conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] col2 = tl.load(conv_states_ptrs, mask_w, 0.0) conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] col1 = tl.load(conv_states_ptrs, mask_w, 0.0) conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) else: # prior-tokens are zeros if KERNEL_WIDTH >= 2: # STRATEGY1 # first chunk and does not have prior-token, so just set to 0 col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 3: # STRATEGY1 col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 4: # STRATEGY1 col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 5: # STRATEGY1 col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) # STEP 2: # here prepare data for 
updating conv_state if ( state_len <= seqlen ): # SMALL_CACHE=True (only move part of 'x' into conv_state cache) # just read from 'x' # copy 'x' data to conv_state # load only 'x' data (and set 0 before 'x' if seqlen < state_len) idx_tokens_last = (seqlen - state_len) + tl.arange( 0, NP2_STATELEN ) # [BLOCK_M] x_ptrs = ( x_ptr + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + (idx_feats * stride_x_dim)[None, :] ) # [BLOCK_M,BLOCK_N,] mask_x = ( (idx_tokens_last >= 0)[:, None] & (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # Compute the offset where the last block should be written in the conv_states conv_states_output_coord = tl.load( conv_state_indices_ptr + idx_seq * stride_cache_indices + current_last_index ).to(tl.int64) conv_states_ptrs_target = ( conv_states_ptr + (conv_states_output_coord * stride_conv_state_seq) # Offset from seq + (idx_feats * stride_conv_state_dim) )[None, :] + ( # [BLOCK_N,] idx_tokens_conv * stride_conv_state_tok )[:, None] mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler tl.store(conv_states_ptrs_target, loaded_x, mask) else: if load_init_state: # update conv_state by shifting left, i.e. 
take last few cols from conv_state + cols from 'x' idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] conv_states_ptrs_source = ( conv_states_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)[None, :] + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] mask = ( (conv_states_input_coord < num_cache_lines) & ((idx_tokens_conv + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :] ) conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) VAL = state_len - seqlen x_ptrs = ( x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] ) # [BLOCK_M, BLOCK_N] mask_x = ( (idx_tokens_conv - VAL >= 0)[:, None] & (idx_tokens_conv - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load new_conv_state = tl.where( mask, conv_state, loaded_x ) # BUG in 'tl.where' which requires a barrier before this conv_states_ptrs_target = ( conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ None, : ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # load_init_state == False # update conv_state by shifting left, BUT # set cols prior to 'x' as zeros + cols from 'x' idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] VAL = state_len - seqlen x_ptrs = ( x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] ) # [BLOCK_M, BLOCK_N] mask_x = ( (idx_tokens_conv - VAL >= 0)[:, None] & (idx_tokens_conv - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index new_conv_state = tl.load(x_ptrs, mask_x, 0.0) conv_states_ptrs_target = ( conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] ) # [BLOCK_M, 
BLOCK_N] mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ None, : ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # chunk_offset > 0 # read prior-token data from `x` load_init_state = True prior_tokens = x_base + (token_offset - 1) * stride_x_token mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 3: conv_states_ptrs = prior_tokens # [BLOCK_N] col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 4: conv_states_ptrs = prior_tokens # [BLOCK_N] col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 5: # ruff: noqa: F841 conv_states_ptrs = prior_tokens # [BLOCK_N] col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") # Store intermediate states aligned with stride_block_m # The additional states are cached starting from the last stride_block_m. # For example: # If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved. 
# If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last # stride_block_m are cached. # For example chunk_offset = n_block_to_fill stores the state at last_full_block if (chunk_offset - 1) < n_block_to_fill: # Store the states at the chunk boundaries from the start of the sequence idx_tokens_last = ( last_full_block_token_index - (n_block_to_fill - chunk_offset) * B_size - state_len ) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] x_ptrs = ( x_ptr + (idx_tokens_last * stride_x_token)[:, None] + (idx_feats * stride_x_dim)[None, :] ) # [BLOCK_M,BLOCK_N,] mask_x = (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[ None, : ] # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # cache_idx conv_states_output_coord = tl.load( conv_state_indices_ptr + idx_seq * stride_cache_indices + current_first_index + (chunk_offset - 1) ).to(tl.int64) conv_states_ptrs_target = ( conv_states_ptr + (conv_states_output_coord * stride_conv_state_seq) # Offset from seq + (idx_feats * stride_conv_state_dim) )[None, :] + ( # [BLOCK_N,] idx_tokens_conv * stride_conv_state_tok )[:, None] mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler tl.store(conv_states_ptrs_target, loaded_x, mask) if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( tl.float32 ) # [BLOCK_N] else: acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) x_base_1d = x_base + token_offset * stride_x_token # starting of chunk # PRE-LOAD WEIGHTS mask_w = idx_feats < dim if KERNEL_WIDTH >= 2: w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor w_col0 = tl.load(w_ptrs, mask_w, other=0.0) w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor w_col1 = tl.load(w_ptrs, mask_w, other=0.0) if KERNEL_WIDTH >= 3: w_ptrs = w_base + (2 * 
stride_w_width) # [BLOCK_N] tensor w_col2 = tl.load(w_ptrs, mask_w, other=0.0) if KERNEL_WIDTH >= 4: w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor w_col3 = tl.load(w_ptrs, mask_w, other=0.0) mask_x_1d = idx_feats < dim for idx_token in range(segment_len): acc = acc_preload matrix_w = w_col0 matrix_x = col0 for j in tl.static_range(KERNEL_WIDTH): if KERNEL_WIDTH == 2: if j == 1: # KERNEL_WIDTH-1: matrix_w = w_col1 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) elif KERNEL_WIDTH == 3: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) elif KERNEL_WIDTH == 4: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 matrix_x = col2 elif j == 3: matrix_w = w_col3 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) acc += matrix_x * matrix_w # [BLOCK_N] if KERNEL_WIDTH == 2: col0 = matrix_x elif KERNEL_WIDTH == 3: col0 = col1 col1 = matrix_x elif KERNEL_WIDTH == 4: col0 = col1 col1 = col2 col2 = matrix_x if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) mask_1d = (idx_token < segment_len) & ( idx_feats < dim ) # token-index # feature-index o_ptrs = ( o_ptr + (sequence_start_index + token_offset + idx_token) * stride_o_token + (idx_feats * stride_o_dim) ) tl.store(o_ptrs, acc, mask=mask_1d) def causal_conv1d_fn( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None, conv_states: torch.Tensor, query_start_loc: torch.Tensor, cache_indices: torch.Tensor | None = None, has_initial_state: torch.Tensor | None = None, activation: str | None = "silu", pad_slot_id: int = PAD_SLOT_ID, block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, num_computed_tokens: torch.Tensor | None = 
None, block_size_to_align=0, metadata=None, validate_data=False, ): """support varlen + continuous batching when x is 2D tensor x: (dim,cu_seq_len) cu_seq_len = total tokens of all seqs in that batch sequences are concatenated from left to right for varlen weight: (dim, width) conv_states: (...,dim,width - 1) itype updated inplace if cache_indices are not provided [it use `cache_indices` to get the index to the cache of conv_state for that sequence conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' ] query_start_loc: (batch + 1) int32 The cumulative sequence lengths of the sequences in the batch, used to index into sequence. prepended by 0. if x = [5, 1, 1, 1] <- continuous batching (batch=4) then query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is the ending index of the last sequence [length(query_start_loc)-1 == batch] for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 indicates the corresponding state index, like so: conv_state = conv_states[cache_indices[batch_id]] has_initial_state: (batch) bool indicates whether should the kernel take the current state as initial state for the calculations [single boolean for each sequence in the batch: True or False] bias: (dim,) activation: either None or "silu" or "swish" or True pad_slot_id: int if cache_indices is passed, lets the kernel identify padded entries that will not be processed, for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 block_idx_first_scheduled_token: (batch,), dtype int32 The pointer into cache_indices, where the first cache block to be filled is located. 
block_idx_last_scheduled_token: (batch,), dtype int32 The pointer into cache_indices, where the last cache block to be filled is located. initial_state_idx: (batch,), dtype int32 The pointer into cache_indices, where the cache block containing the initial state is located. num_computed_tokens: (batch,), dtype int32 The number of tokens already completed for each sequence block_size_to_align: int The block size to align the cached states to out: same shape as `x` """ if isinstance(activation, bool) and activation: activation = "silu" args = None # Store original dtype to cast back at the end original_x_dtype = x.dtype x = x.to(conv_states.dtype) out = torch.empty_like(x) if metadata is not None: nums_dict = metadata.nums_dict args = nums_dict batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr else: seqlens = query_start_loc.diff().to("cpu") args = seqlens MAX_NUM_PROGRAMS = 1024 batch_ptr = torch.full( (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking which seq-idx the Triton program is handling token_chunk_offset_ptr = torch.full( (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking BLOCK_M-based index in the sequence the Triton program is handling is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) dim, cu_seqlen = x.shape _, width = weight.shape state_len = width - 1 np2_statelen = triton.next_power_of_2(state_len) padded_batch = query_start_loc.size(0) - 1 stride_x_dim = x.stride(0) stride_x_token = x.stride(1) stride_w_dim = weight.stride(0) stride_w_width = weight.stride(1) stride_istate_seq = 0 stride_istate_dim = 0 stride_istate_token = 0 num_cache_lines = 0 BLOCK_M = 8 if conv_states is not None: # extensions to support vLLM: # 1. conv_states is used to replaced initial_states # 2. conv_states serve as a cache with num cache lines can be larger than batch size # 3. 
mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] # 4. computation can be skipped if cache_indices[idx] == pad_slot_id num_cache_lines = conv_states.size(0) assert ( num_cache_lines == conv_states.shape[0] and dim == conv_states.shape[1] and width - 1 <= conv_states.shape[2] ) stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) assert stride_istate_dim == 1 if out.dim() == 2: stride_o_dim = out.stride(0) stride_o_token = out.stride(1) else: stride_o_dim = out.stride(1) stride_o_token = out.stride(2) stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0 if validate_data: assert x.dim() == 2 assert query_start_loc is not None assert query_start_loc.dim() == 1 assert x.stride(0) == 1 or x.stride(1) == 1 if bias is not None: assert bias.dim() == 1 assert dim == bias.size(0) if cache_indices is not None: assert cache_indices.dim() == 1 assert padded_batch == cache_indices.size(0) if has_initial_state is not None: assert has_initial_state.size() == (padded_batch,) assert conv_states is not None, ( "ERROR: `has_initial_state` is used, which needs also `conv_states`" ) assert weight.stride(1) == 1 assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" if block_size_to_align is not None and block_size_to_align > 0: assert (block_size_to_align % BLOCK_M) == 0, ( "The mamba block size needs to be divisible by the BLOCK_M" ) else: block_size_to_align = BLOCK_M if metadata is None: def num_program(META, seqlens): tot = 0 mlist = [] offsetlist = [] # type: ignore nums = -(-seqlens // META["BLOCK_M"]) tot = nums.sum().item() mlist = np.repeat(np.arange(len(nums)), nums) for idx, num in enumerate(nums): offsetlist.extend( range(num) ) # chunk-idx if a sequence is split into multiple chunks if META["batch_ptr"].nelement() < len(mlist): newlen = len(mlist) + 1 
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if META["batch_ptr"].nelement() >= len(mlist): META["batch_ptr"][0 : len(mlist)].copy_( torch.from_numpy(np.array(mlist)) ) META["token_chunk_offset_ptr"][0 : len(mlist)].copy_( torch.from_numpy(np.array(offsetlist)) ) META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to( META["x_ptr"].device ) return tot else: def num_program(META, nums_dict): tot = nums_dict[META["BLOCK_M"]]["tot"] mlist = nums_dict[META["BLOCK_M"]]["mlist"] mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][ "token_chunk_offset_ptr" ] else: if META["batch_ptr"].nelement() < mlist_len: newlen = mlist_len + 1 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if META["batch_ptr"].nelement() >= mlist_len: META["batch_ptr"][0:mlist_len].copy_(mlist) META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) return tot def grid(META): return ( num_program(META, args), triton.cdiv(dim, META["BLOCK_N"]), ) if batch_ptr.device != x.device: batch_ptr = batch_ptr.to(x.device) token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device) _causal_conv1d_fwd_kernel[grid]( # Pointers to matrices x, weight, bias, conv_states, cache_indices, has_initial_state, query_start_loc, batch_ptr, token_chunk_offset_ptr, block_idx_first_scheduled_token, block_idx_last_scheduled_token, initial_state_idx, num_computed_tokens, out, # Matrix dimensions dim, cu_seqlen, num_cache_lines, # stride stride_x_dim, stride_x_token, stride_w_dim, stride_w_width, stride_istate_seq, stride_istate_dim, stride_istate_token, stride_cache_indices, 
stride_o_dim, stride_o_token, block_size_to_align // BLOCK_M, # others pad_slot_id, # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_APC_ENABLED=block_idx_last_scheduled_token is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, # launch_cooperative_grid=True BLOCK_M=BLOCK_M, BLOCK_N=256, num_stages=2, ) return out.to(original_x_dtype) @triton.jit() def _causal_conv1d_update_kernel( # Pointers to matrices x_ptr, # (batch, dim, seqlen) w_ptr, # (dim, width) bias_ptr, conv_state_ptr, conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) block_idx_last_scheduled_token, # (batch,) initial_state_idx, # (batch,) o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, dim: tl.constexpr, seqlen: tl.constexpr, state_len: tl.constexpr, num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides stride_x_seq: tl.constexpr, stride_x_dim: tl.constexpr, stride_x_token: tl.constexpr, stride_w_dim: tl.constexpr, stride_w_width: tl.constexpr, stride_conv_state_seq: tl.constexpr, stride_conv_state_dim: tl.constexpr, stride_conv_state_tok: tl.constexpr, stride_state_indices: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, # others pad_slot_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, IS_APC_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, BLOCK_N: tl.constexpr, ): # ruff: noqa: E501 idx_seq = tl.program_id(0) if idx_seq >= batch: return # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) if IS_APC_ENABLED: # Get the state from the initial_state_idx conv_state_init = tl.load(initial_state_idx + idx_seq) current_last_index = tl.load(block_idx_last_scheduled_token + 
idx_seq) else: conv_state_init = 0 current_last_index = 0 # cache_idx conv_states_input_coord = tl.load( conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init ).to(tl.int64) if USE_PAD_SLOT: # noqa if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return if IS_VARLEN: query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) # revise state_len and seqlen state_len = state_len - (seqlen - (query_end_index - query_start_index)) seqlen = query_end_index - query_start_index x_offset = query_start_index * stride_x_token o_offset = query_start_index * stride_o_token else: query_start_index = idx_seq * seqlen query_end_index = query_start_index + seqlen x_offset = idx_seq * stride_x_seq o_offset = idx_seq * stride_o_seq if query_start_index == query_end_index: return if IS_SPEC_DECODING: # The rolling of conv state: # # Before forward, the conv_state is: # [history1, history2, ..., historyM]. # # After forward, the conv_state becomes: # [history2, ..., historyM, draft1, draft2, ..., draftN]. # # After acceptance, it becomes: # # - accept 1 tokens: [history2, ..., historyM, draft1] # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] # - and so on. 
conv_state_token_offset = ( tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 ) else: conv_state_token_offset = 0 # STEP 1: READ init_state data conv_states_base = ( conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) ) mask_w = idx_feats < dim prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok if KERNEL_WIDTH >= 2: conv_states_ptrs = prior_tokens # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) if KERNEL_WIDTH >= 3: conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] col1 = tl.load(conv_states_ptrs, mask_w, 0.0) if KERNEL_WIDTH >= 4: conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] col2 = tl.load(conv_states_ptrs, mask_w, 0.0) if KERNEL_WIDTH >= 5: conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] col3 = tl.load(conv_states_ptrs, mask_w, 0.0) if KERNEL_WIDTH >= 6: conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] col4 = tl.load(conv_states_ptrs, mask_w, 0.0) # STEP 2: assume state_len > seqlen idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] # With speculative decoding, the conv_state updates works in a sliding # window manner, at each forward pass, the tokens are shift by 1, so we # load since idx_tokens + 1. 
conv_state_ptrs_source = ( conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + conv_state_token_offset * stride_conv_state_tok + (idx_feats * stride_conv_state_dim)[None, :] + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ :, None ] ) # [BLOCK_M, BLOCK_N] mask = ( (conv_states_input_coord < num_cache_lines) & ((idx_tokens + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :] ) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) VAL = state_len - seqlen x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] x_ptrs = ( x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] ) # [BLOCK_M, BLOCK_N] mask_x = ( (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) tl.debug_barrier() new_conv_state = tl.where(mask, conv_state, loaded_x) # Get the state from the initial_state_idx # cache_idx conv_states_offset = tl.load( conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index ).to(tl.int64) conv_state_ptrs_target = ( conv_state_ptr + (conv_states_offset * stride_conv_state_seq) # Offset from seq + (idx_feats * stride_conv_state_dim) )[None, :] + ( # [BLOCK_N,] idx_tokens * stride_conv_state_tok )[:, None] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) # STEP 3: init accumulator if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( tl.float32 ) # [BLOCK_N] else: acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) # STEP 4: # PRE-LOAD WEIGHTS # first kernel column, configured for weights to handle BLOCK_N features in range w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] mask_w = idx_feats < dim if KERNEL_WIDTH >= 2: w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor w_col0 
= tl.load(w_ptrs, mask_w, other=0.0) w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor w_col1 = tl.load(w_ptrs, mask_w, other=0.0) if KERNEL_WIDTH >= 3: w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor w_col2 = tl.load(w_ptrs, mask_w, other=0.0) if KERNEL_WIDTH >= 4: w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor w_col3 = tl.load(w_ptrs, mask_w, other=0.0) if KERNEL_WIDTH >= 5: w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor w_col4 = tl.load(w_ptrs, mask_w, other=0.0) if KERNEL_WIDTH >= 6: w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor w_col5 = tl.load(w_ptrs, mask_w, other=0.0) x_base_1d = x_base # starting of chunk [BLOCK_N] mask_x_1d = idx_feats < dim # STEP 5: compute each token for idx_token in tl.range(seqlen): acc = acc_preload matrix_w = w_col0 matrix_x = col0 for j in tl.static_range(KERNEL_WIDTH): if KERNEL_WIDTH == 2: if j == 1: # KERNEL_WIDTH-1: matrix_w = w_col1 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) elif KERNEL_WIDTH == 3: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) elif KERNEL_WIDTH == 4: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 matrix_x = col2 elif j == 3: matrix_w = w_col3 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) elif KERNEL_WIDTH == 5: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 matrix_x = col2 elif j == 3: matrix_w = w_col3 matrix_x = col3 elif j == 4: matrix_w = w_col4 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) elif KERNEL_WIDTH == 6: if j == 1: matrix_w = w_col1 matrix_x = col1 elif j == 2: matrix_w = w_col2 matrix_x = col2 elif j == 3: matrix_w = w_col3 matrix_x = col3 elif j == 4: matrix_w 
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    activation: bool | str | None = None,
    conv_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    query_start_loc: torch.Tensor | None = None,
    max_query_len: int = -1,
    pad_slot_id: int = PAD_SLOT_ID,
    block_idx_last_scheduled_token: torch.Tensor | None = None,
    initial_state_idx: torch.Tensor | None = None,
    validate_data=False,
):
    """
    Host-side launcher for `_causal_conv1d_update_kernel`.

    The result is written into the buffer of `x` itself (after `x` has been
    converted to the dtype of `conv_state`); no separate output tensor is
    allocated. The returned tensor is converted back to the dtype `x` had
    on entry.

    x: input tensor, in one of these layouts:
        - `[batch, dim]`: single token prediction
        - `[batch, dim, seqlen]`: single or multiple tokens prediction
        - `[num_tokens, dim]`: continuous batching, where num_tokens is the
          total number of tokens of all sequences in the batch
          (used when `query_start_loc` is given)
    conv_state: (..., dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    activation: `True`, "silu" or "swish" enable the SiLU activation;
        `False` or `None` disable it.
    conv_state_indices: (batch,), dtype int32
        If not None, conv_state is a larger tensor along the batch dim and
        these are the batch coordinates to select. Useful for continuous
        batching.
    block_idx_last_scheduled_token: (batch,), dtype int32
        Pointer into conv_state_indices locating the last cache block to be
        filled. Passing it (not None) turns on the kernel's IS_APC_ENABLED
        flag.
    initial_state_idx: (batch,), dtype int32
        Pointer into conv_state_indices locating the cache block that holds
        the initial state.
    num_accepted_tokens: (batch,), dtype int32
        If not None, the number of accepted tokens for each sequence in the
        batch. Used in speculative decoding, where conv_state is updated in
        a sliding window manner.
    query_start_loc: (batch + 1,) int32
        If not None, the input is given in a varlen fashion and this holds
        the starting index of each sequence in the batch.
    max_query_len: int
        If query_start_loc is not None, the maximum query length in the
        batch.
    pad_slot_id: int
        If conv_state_indices is passed, lets the kernel identify padded
        entries that will not be processed. For example, with
        conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id]
        the kernel will not process the entries at indices 0 and 3.
    validate_data: when truthy, run extra shape/stride assertions before
        launching the kernel.

    out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim),
        same shape as `x`
    """
    is_varlen = query_start_loc is not None

    if validate_data:
        assert pad_slot_id is not None
        assert x.stride(1) == 1

    # Collapse the bool form of `activation` onto the string/None form.
    if isinstance(activation, bool):
        activation = "silu" if activation else None
    elif activation is not None:
        assert activation in ["silu", "swish"]

    input_dtype = x.dtype
    x = x.to(conv_state.dtype)

    # A 2-D dense input is one token per sequence: add a trailing token axis
    # so it becomes (batch, dim, 1); the axis is removed again before return.
    added_token_axis = (not is_varlen) and x.dim() == 2
    if added_token_axis:
        x = x.unsqueeze(-1)

    if is_varlen:
        assert conv_state_indices is not None
        batch = conv_state_indices.size(0)
        dim = x.size(1)
        seqlen = max_query_len
    else:
        batch, dim, seqlen = x.shape

    _, width = weight.shape
    # conv_state: (..., dim, state_len), where state_len >= width - 1
    num_cache_lines, _, cached_state_len = conv_state.size()

    if validate_data:
        assert weight.size(0) == dim
        assert conv_state.stride(-2) == 1, (
            f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
        )
        # when the cache is longer than width - 1, we don't shift-left to
        # keep any records in conv_state
        assert cached_state_len >= width - 1
        assert conv_state.size(1) == dim
        if conv_state_indices is not None:
            assert conv_state_indices.shape == (batch,)
        else:
            assert conv_state.size(0) >= batch
        assert num_cache_lines >= batch
        assert weight.stride(1) == 1  # Need this

    # Follow the vLLM strategy of overwriting 'x' directly rather than
    # creating a new output tensor.
    out = x

    stride_w_dim, stride_w_width = weight.stride()

    if is_varlen:
        # x is laid out token-major here, so there is no per-sequence stride.
        stride_x_token, stride_x_dim = x.stride()
        stride_o_token, stride_o_dim = out.stride()
        stride_x_seq = 0
        stride_o_seq = 0
    else:
        # x: (batch, dim, seqlen)
        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
        stride_o_seq, stride_o_dim, stride_o_token = out.stride()

    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()

    if conv_state_indices is None:
        stride_state_indices = 0
    else:
        stride_state_indices = conv_state_indices.stride(0)

    # Effective state length handed to the kernel: width - 1, extended by
    # (seqlen - 1) when speculative decoding is in use.
    effective_state_len = width - 1
    if num_accepted_tokens is not None:
        effective_state_len += seqlen - 1
    np2_statelen = triton.next_power_of_2(effective_state_len)

    def _launch_grid(meta):
        # One program per (sequence, BLOCK_N-wide slice of the feature dim).
        return (batch, triton.cdiv(dim, meta["BLOCK_N"]))

    apply_silu = activation in ["silu", "swish"]

    _causal_conv1d_update_kernel[_launch_grid](
        # Pointers to matrices
        x,
        weight,
        bias,
        conv_state,
        conv_state_indices,
        num_accepted_tokens,
        query_start_loc,
        block_idx_last_scheduled_token,
        initial_state_idx,
        out,
        # Matrix dimensions
        batch,
        dim,
        seqlen,
        effective_state_len,
        num_cache_lines,
        # Strides
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        # Others
        pad_slot_id,
        # Meta-parameters
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=apply_silu,
        IS_VARLEN=is_varlen,
        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=256,
    )

    if added_token_axis:
        out = out.squeeze(-1)
    return out.to(input_dtype)