From 7372225bcb0bd4896f43e989cef8109bfe45b13c Mon Sep 17 00:00:00 2001 From: Qi Mao Date: Fri, 26 Dec 2025 09:12:30 +0800 Subject: [PATCH] [FIX] Update _causal_conv1d_update_kernel for Efficient Conv State Handling on NPU (#5322) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Description: This PR updates the implementation of the Triton operator for deployment on NPU devices, focusing on optimizing grid size and memory handling based on NPU limitations. Design Plan: Grid Calculation: The grid size is now dynamically calculated by batch and dim to ensure that the number of programs executed does not exceed the NPU's vector core capacity. This ensures optimal parallelism without overloading the hardware. Data Block Handling: Due to the limited on-chip memory (UB) on Ascend NPUs, this implementation splits large data into smaller chunks of 32k or less per block. The kernel performs a for-loop to process the data in these smaller chunks, minimizing memory usage and avoiding potential overflows. Changes Compared to GPU Implementation: Grid and Block Sizing: For GPU, the grid and block size were determined based on available thread counts and memory size. In contrast, the NPU version dynamically adjusts these parameters using B_TILE and BLOCK_N to optimize for NPU’s architecture. Memory Chunking: The original GPU implementation did not require chunking due to the higher available memory and processing capacity. For the NPU, data is divided into smaller chunks (32k or smaller) to comply with memory constraints on the device. The kernel has been modified to handle this chunking mechanism inside a loop. Optimized Thread Usage: The NPU implementation takes into account the hardware-specific thread limit (24 threads per vector core), ensuring that the number of active programs is aligned with the NPU's vector core count, avoiding over-subscription that would lead to serial processing. This PR ensures that the operator functions efficiently on Ascend NPU, considering hardware limitations while maintaining the same functionality and input parameters as the GPU implementation. - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/5fbfa8d9ef15948599631baeb91e8220b2ee9bcc Signed-off-by: maoxx241 --- .../nightly/ops/triton/test_causal_conv1d.py | 153 ++- vllm_ascend/ops/triton/mamba/causal_conv1d.py | 1162 ++++++----------- 2 files changed, 563 insertions(+), 752 deletions(-) diff --git a/tests/e2e/nightly/ops/triton/test_causal_conv1d.py b/tests/e2e/nightly/ops/triton/test_causal_conv1d.py index fe65364f..fe4eac0f 100644 --- a/tests/e2e/nightly/ops/triton/test_causal_conv1d.py +++ b/tests/e2e/nightly/ops/triton/test_causal_conv1d.py @@ -6,6 +6,8 @@ import torch.nn.functional as F from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID, causal_conv1d_fn) +from vllm_ascend.ops.triton.mamba.causal_conv1d import \ + causal_conv1d_update_npu as causal_conv1d_update def validate_cmp(y_cal, y_ref, dtype, device='npu'): @@ -156,17 +158,13 @@ def causal_conv1d_fn_pytorch( @pytest.mark.parametrize('has_initial_state', [False, True]) -@pytest.mark.parametrize('itype', - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('silu_activation', [False, True]) -@pytest.mark.parametrize('has_bias', [False, True]) -@pytest.mark.parametrize('seq_len', [[ - 1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, - 2048, 4096 -]]) +@pytest.mark.parametrize('itype', [torch.bfloat16]) +@pytest.mark.parametrize('silu_activation', [True]) +@pytest.mark.parametrize('has_bias', [True]) +@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]]) @pytest.mark.parametrize('extra_state_len', [0, 2]) -@pytest.mark.parametrize('width', [2, 3, 4]) -@pytest.mark.parametrize('dim', [64, 4160]) +@pytest.mark.parametrize('width', [2, 4]) +@pytest.mark.parametrize('dim', [4160]) def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias, silu_activation, itype, has_initial_state): @@ -227,4 +225,137 @@ def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias, query_start_loc=query_start_loc) validate_cmp(out, out_ref, itype) - validate_cmp(conv_states, conv_states_ref, itype) \ No newline at end of file + validate_cmp(conv_states, conv_states_ref, itype) + + +def causal_conv1d_update_ref(x, + conv_state, + weight, + bias=None, + activation=None, + cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the + conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to( + weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange( + -(width - 1), 0, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = (torch.remainder(width_idx, state_len).unsqueeze(1).expand( + -1, dim, -1)) + x_new = torch.cat([conv_state.gather(2, width_idx), x], + dim=-1).to(weight.dtype) + copy_idx = torch.arange( + seqlen, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, + state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, + groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("seqlen", [1, 3]) +@pytest.mark.parametrize("width", [3, 4]) +@pytest.mark.parametrize("dim", [2048 + 16, 4096]) +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("batch_size", [3, 64]) +def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, + width, seqlen, has_bias, + silu_activation, itype): + device = "npu" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + # total_entries = number of cache line + total_entries = 10 * batch_size + + # x will be (batch, dim, seqlen) with contiguous along dim-axis + x = torch.randn(padded_batch_size, seqlen, dim, device=device, + dtype=itype).transpose(1, 2) + + x_ref = x.clone() + + conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[conv_state_indices] = False + padded_state_indices = torch.concat( + [ + conv_state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) + + # conv_state will be (cache_lines, dim, state_len) + # with contiguous along dim-axis + conv_state = torch.randn(total_entries, + width - 1, + dim, + device=device, + dtype=itype).transpose(1, 2) + + conv_state_for_padding_test = conv_state.clone() + + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + conv_state_ref = conv_state[conv_state_indices, :].detach().clone() + activation = None if not silu_activation else "silu" + + out = causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = causal_conv1d_update_ref(x_ref[:batch_size], + conv_state_ref, + weight, + bias, + activation=activation) + + assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) + assert torch.equal(conv_state[unused_states_bool], + conv_state_for_padding_test[unused_states_bool]) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index fafa00c2..38b838d8 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -441,317 +441,407 @@ def causal_conv1d_fn(x: torch.Tensor, return out -# TODO copied from vllm and it needs to be optimized -@triton.jit() -def _original_causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - conv_state_indices_ptr, - num_accepted_tokens_ptr, - query_start_loc_ptr, # (batch + 1) - block_idx_last_scheduled_token, # (batch,) - initial_state_idx, # (batch,) - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_APC_ENABLED: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, +@triton.jit +def _causal_conv1d_update_kernel_npu_tiled( + # Pointers + x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, # (num_cache_lines, dim, state_len) + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) + o_ptr, # same shape as x_ptr + batch: tl.int32, + dim: tl.constexpr, + seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen + state_len: tl.constexpr, # effective state_len computed in wrapper + num_cache_lines: tl.constexpr, + + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + + # others + pad_slot_id: tl.constexpr, + + # Meta + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, # <= 6 + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + + # tiling + BLOCK_N: tl.constexpr, # channel tile (C_TILE) + B_TILE: tl.constexpr, # batch tile + T_CHUNK: tl.constexpr, # token chunk for state update ): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: - return + # program ids + pid_b = tl.program_id(0) # batch-tile id + pid_c = tl.program_id(1) # channel-tile id - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_APC_ENABLED: - # Get the state from the initial_state_idx - conv_state_init = tl.load(initial_state_idx + idx_seq) - current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) - else: - conv_state_init = 0 - current_last_index = 0 - - # cache_idx - conv_states_input_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init).to(tl.int64) - - if USE_PAD_SLOT: # noqa - if conv_states_input_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_VARLEN: - query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) - query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to( - tl.int64) - # revise state_len and seqlen - state_len = state_len - (seqlen - - (query_end_index - query_start_index)) - seqlen = query_end_index - query_start_index - x_offset = query_start_index * stride_x_token - o_offset = query_start_index * stride_o_token - else: - query_start_index = idx_seq * seqlen - query_end_index = query_start_index + seqlen - x_offset = idx_seq * stride_x_seq - o_offset = idx_seq * stride_o_seq - - if query_start_index == query_end_index: - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = ( - tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1) - else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_states_input_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) + # channel indices for this program + idx_feats = pid_c * BLOCK_N + tl.arange(0, BLOCK_N) # [BLOCK_N] mask_w = idx_feats < dim - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + # preload weights once per program (shared by B_TILE sequences) + w_base = w_ptr + idx_feats * stride_w_dim + # define to avoid "undefined" in branches + w_col0 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col1 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col2 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col3 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col4 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col5 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + if KERNEL_WIDTH >= 1: + w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 6: - conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] - col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # With speculative decoding, the conv_state updates works in a sliding - # window manner, at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. - conv_state_ptrs_source = ( - conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * - stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] - mask = ((conv_states_input_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] - - x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens - VAL >= 0)[:, None] - & (idx_tokens - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - # Get the state from the initial_state_idx - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - current_last_index).to(tl.int64) - conv_state_ptrs_target = ( - conv_state_ptr + - (conv_states_offset * stride_conv_state_seq) # Offset from seq - + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] - idx_tokens * stride_conv_state_tok)[:, None] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator + # bias vector once per program if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, + other=0.0).to(tl.float32) else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_bias = tl.zeros((BLOCK_N, ), dtype=tl.float32) - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 5: - w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor - w_col4 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 6: - w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor - w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + # token index vector for chunked copy + tok_vec = tl.arange(0, T_CHUNK) # [T_CHUNK] - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim + # process B_TILE sequences inside the same program instance + for bi in tl.static_range(0, B_TILE): + b = pid_b * B_TILE + bi # scalar tl.int32 + lane_active = b < batch # scalar predicate - # STEP 5: compute each token - for idx_token in tl.range(seqlen): - acc = acc_preload + # ------------------------- + # APC mapping (optional) + # ------------------------- + if IS_APC_ENABLED: + conv_state_init = tl.load(initial_state_idx + b, + mask=lane_active, + other=0).to(tl.int32) + current_last_index = tl.load(block_idx_last_scheduled_token + b, + mask=lane_active, + other=0).to(tl.int32) + else: + conv_state_init = tl.full((), 0, tl.int32) + current_last_index = tl.full((), 0, tl.int32) - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): + # input cache line + conv_states_input_coord = tl.load(conv_state_indices_ptr + + b * stride_state_indices + + conv_state_init, + mask=lane_active, + other=0).to(tl.int64) + + if USE_PAD_SLOT: + lane_active = lane_active & (conv_states_input_coord + != pad_slot_id) + + # ------------------------- + # varlen (optional): revise seqlen_run and state_len_run like original kernel does + # ------------------------- + if IS_VARLEN: + qs = tl.load(query_start_loc_ptr + b, mask=lane_active, + other=0).to(tl.int64) + qe = tl.load(query_start_loc_ptr + (b + 1), + mask=lane_active, + other=0).to(tl.int64) + seqlen_run = (qe - qs).to(tl.int32) + # revise effective state_len for shorter sequences (same formula as original) + state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32) + x_offset = (qs * stride_x_token).to(tl.int64) + o_offset = (qs * stride_o_token).to(tl.int64) + else: + seqlen_run = tl.full((), seqlen, tl.int32) + state_len_run = tl.full((), state_len, tl.int32) + x_offset = (b * stride_x_seq).to(tl.int64) + o_offset = (b * stride_o_seq).to(tl.int64) + + # empty sequence -> skip (avoid early return because other lanes in tile) + lane_active = lane_active & (seqlen_run > 0) + + # ------------------------- + # spec decoding offset (optional) + # ------------------------- + if IS_SPEC_DECODING: + conv_state_token_offset = ( + tl.load(num_accepted_tokens_ptr + b, mask=lane_active, + other=1).to(tl.int64) - 1) + shift = tl.full((), 1, tl.int32) # sliding by 1 in spec mode + else: + conv_state_token_offset = tl.full((), 0, tl.int64) + shift = seqlen_run # normal mode shift by seqlen + + # ------------------------- + # STEP 1: read initial history cols BEFORE state update (out==x safe) + # ------------------------- + conv_states_base = (conv_state_ptr + + conv_states_input_coord * stride_conv_state_seq + + idx_feats * stride_conv_state_dim) + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + + # define history vectors as zeros then load conditionally + col0 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col1 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col2 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col3 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col4 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + if KERNEL_WIDTH >= 2: + col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 3: + col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 4: + col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 5: + col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 6: + col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + + # ------------------------- + # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block) + # Semantics: conv_state <- concat(old_state, x)[-state_len_run:]. + # - If seqlen_run >= state_len_run: dst[:] = x[seqlen_run - state_len_run : seqlen_run] + # - Else: keep = state_len_run - seqlen_run, + # dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run] + # ------------------------- + # output cache line + conv_states_offset = tl.load(conv_state_indices_ptr + + b * stride_state_indices + + current_last_index, + mask=lane_active, + other=0).to(tl.int64) + + use_shift = (seqlen_run < state_len_run) + use_tail = (seqlen_run >= state_len_run) + + zero_i32 = tl.full((), 0, tl.int32) + keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), + zero_i32).to(tl.int32) + tail_start = tl.where(use_tail, (seqlen_run - state_len_run), + zero_i32).to(tl.int32) + + # base pointers + state_src_base = (conv_state_ptr + + conv_states_input_coord * stride_conv_state_seq + + conv_state_token_offset * stride_conv_state_tok + + idx_feats * stride_conv_state_dim) + state_dst_base = (conv_state_ptr + + conv_states_offset * stride_conv_state_seq + + idx_feats * stride_conv_state_dim) + + x_base = x_ptr + x_offset + idx_feats * stride_x_dim + + # A) shift old state into dst[0:keep_shift) (only when seqlen_run < state_len_run) + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] + src_tok = (dst_tok + shift).to(tl.int32) # [T_CHUNK] + m_tok = use_shift & (dst_tok < keep_shift) & ( + src_tok < state_len_run) & (dst_tok < state_len_run) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( + conv_states_input_coord + < num_cache_lines) & (conv_states_offset < num_cache_lines) + + src_ptrs = state_src_base[ + None, :] + src_tok[:, None] * stride_conv_state_tok + dst_ptrs = state_dst_base[ + None, :] + dst_tok[:, None] * stride_conv_state_tok + vals = tl.load(src_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, vals, mask=m) + + # B) append x into dst[keep_shift : keep_shift+seqlen_run) (only when seqlen_run < state_len_run) + for t0 in tl.static_range(0, seqlen, T_CHUNK): + x_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] + dst_tok = (keep_shift + x_tok).to(tl.int32) # [T_CHUNK] + m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok + < state_len_run) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( + conv_states_offset < num_cache_lines) + + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = state_dst_base[ + None, :] + dst_tok[:, None] * stride_conv_state_tok + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + + # C) if seqlen_run >= state_len_run, overwrite dst with the tail of x + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] + x_tok = (tail_start + dst_tok).to(tl.int32) # [T_CHUNK] + m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( + conv_states_offset < num_cache_lines) + + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = state_dst_base[ + None, :] + dst_tok[:, None] * stride_conv_state_tok + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + + # ------------------------- + # STEP 3/4/5: causal conv1d (+ optional SiLU) and store output + # This is original STEP3~5, but per-lane and without debug_barrier. + # ------------------------- + x_base_1d = x_base + o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim + + # accumulator preload (bias) + acc_preload = acc_bias + + # compute each token; keep tl.range so varlen can use seqlen_run as runtime trip count (like original) + for idx_token in tl.range(seqlen_run): + acc = acc_preload + + # same selection logic as original (unrolled by KERNEL_WIDTH) + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 1: + # only x[t] * w0 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + matrix_w = w_col0 + elif KERNEL_WIDTH == 2: + if j == 1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + + acc += matrix_x.to(tl.float32) * matrix_w # [BLOCK_N] + + # roll history window if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = matrix_x elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = matrix_x elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = col2 + col2 = matrix_x elif KERNEL_WIDTH == 5: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - matrix_x = col3 - elif j == 4: - matrix_w = w_col4 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x elif KERNEL_WIDTH == 6: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - matrix_x = col3 - elif j == 4: - matrix_w = w_col4 - matrix_x = col4 - elif j == 5: - matrix_w = w_col5 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x - acc += matrix_x * matrix_w # [BLOCK_N] + if SILU_ACTIVATION: + acc = acc / (1.0 + tl.exp(-acc)) - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - elif KERNEL_WIDTH == 5: - col0 = col1 - col1 = col2 - col2 = col3 - col3 = matrix_x - elif KERNEL_WIDTH == 6: - col0 = col1 - col1 = col2 - col2 = col3 - col3 = col4 - col4 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & (idx_feats < dim - ) # token-index # feature-index - o_ptrs = (o_ptr + o_offset + idx_token * stride_o_token + - (idx_feats * stride_o_dim)) - - tl.store(o_ptrs, acc, mask=mask_1d) + # store output + o_ptrs = o_base_1d + idx_token * stride_o_token + tl.store(o_ptrs, acc, mask=lane_active & mask_w) -# TODO copied from vllm and it needs to be optimized -def original_causal_conv1d_update( +def causal_conv1d_update_npu( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, @@ -818,6 +908,7 @@ def original_causal_conv1d_update( if unsqueeze: # make it (batch, dim, seqlen) with seqlen == 1 x = x.unsqueeze(-1) + if query_start_loc is None: batch, dim, seqlen = x.shape else: @@ -825,36 +916,25 @@ def original_causal_conv1d_update( batch = conv_state_indices.size(0) dim = x.size(1) seqlen = max_query_len + _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() + num_cache_lines, _, state_len_total = conv_state.size() if validate_data: assert dim == weight.size(0) - assert conv_state.stride(-2) == 1, ( - f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - ) - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch, ) == conv_state_indices.shape - + assert conv_state.stride(-2) == 1 + assert state_len_total >= width - 1 assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this + assert weight.stride(1) == 1 - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + # overwrite-on-x strategy same as original out = x - stride_w_dim, stride_w_width = weight.stride() + stride_w_dim, stride_w_width = weight.stride() if query_start_loc is None: - # X (batch, dim, seqlen) stride_x_seq, stride_x_dim, stride_x_token = x.stride() stride_o_seq, stride_o_dim, stride_o_token = out.stride() else: - # X (dim, cu_seqlen) stride_x_token, stride_x_dim = x.stride() stride_x_seq = 0 stride_o_token, stride_o_dim = out.stride() @@ -862,22 +942,46 @@ def original_causal_conv1d_update( stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( ) - stride_state_indices = (conv_state_indices.stride(0) - if conv_state_indices is not None else 0) + stride_state_indices = conv_state_indices.stride( + 0) if conv_state_indices is not None else 0 + + # effective state_len exactly as original if num_accepted_tokens is not None: - state_len = width - 1 + (seqlen - 1) # effective state_len needed + eff_state_len = width - 1 + (seqlen - 1) else: - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) + eff_state_len = width - 1 + np2_statelen = triton.next_power_of_2(eff_state_len) + + # -------- tiling heuristic-------- + #keep program count around ~[80..160] + # vector core 40 + # TODO: use driver to get the vector core num + CORE_HINT = 40 + # channel tile: 512 when dim large (reduce tasks), else 256 + block_n = 512 if dim >= 512 else 256 + g = triton.cdiv(dim, block_n) + target = 2 * CORE_HINT # ~80 + b_tile_raw = max(1, (batch * g + target - 1) // target) + # clamp to small set + if b_tile_raw <= 1: + b_tile = 1 + elif b_tile_raw <= 2: + b_tile = 2 + elif b_tile_raw <= 4: + b_tile = 4 + else: + b_tile = 8 + + # token chunk based on block_n (32KB UB idea); conservative + t_chunk = 20 if block_n == 512 else 48 def grid(META): return ( - batch, + triton.cdiv(batch, META["B_TILE"]), triton.cdiv(dim, META["BLOCK_N"]), ) - _original_causal_conv1d_update_kernel[grid]( - # Pointers to matrices + _causal_conv1d_update_kernel_npu_tiled[grid]( x, weight, bias, @@ -888,13 +992,11 @@ def original_causal_conv1d_update( block_idx_last_scheduled_token, initial_state_idx, out, - # Matrix dimensions batch, dim, seqlen, - state_len, + eff_state_len, num_cache_lines, - # stride stride_x_seq, stride_x_dim, stride_x_token, @@ -907,9 +1009,7 @@ def original_causal_conv1d_update( stride_o_seq, stride_o_dim, stride_o_token, - # others pad_slot_id, - # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], @@ -918,431 +1018,11 @@ def original_causal_conv1d_update( IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=256, + BLOCK_N=block_n, + B_TILE=b_tile, + T_CHUNK=t_chunk, ) + if unsqueeze: out = out.squeeze(-1) return out.to(original_x_dtype) - - -@triton.jit() -def _causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - cache_seqlens_ptr, # circular buffer - conv_state_indices_ptr, - num_accepted_tokens_ptr, - intermediate_conv_window_ptr, - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_inter_seq: tl.constexpr, - stride_inter_step: tl.constexpr, - stride_inter_dim: tl.constexpr, - stride_inter_win: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, - SAVE_INTERMEDIATE: tl.constexpr, -): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: - return - - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) - else: - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = tl.load(num_accepted_tokens_ptr + - idx_seq) - 1 - else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) - mask_w = idx_feats < dim - - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok - if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # The conv_state updates works in a sliding window manner, - # at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. - conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + 1) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] - - x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens - VAL >= 0)[:, None] - & (idx_tokens - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = (conv_state_base + - (idx_tokens * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) - - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim - - # STEP 5: compute each token - for idx_token in tl.static_range(seqlen): - acc = acc_preload - - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - # mask_1d = (idx_token < seqlen) & ( - # idx_feats < dim - # ) # token-index # feature-index - maskL = idx_feats < dim - maskR = tl.full(maskL.shape, False, tl.int1) - mask_1d = tl.where(idx_token < seqlen, maskL, maskR) - - o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + - idx_token * stride_o_token + (idx_feats * stride_o_dim)) - - tl.store(o_ptrs, acc, mask=mask_1d) - - if SAVE_INTERMEDIATE: - # Save the window state after consuming this token - # Layout: [seq(cache line), step, dim, win(K-1)] - base_ptr = (intermediate_conv_window_ptr + - conv_state_batch_coord * stride_inter_seq + - idx_token * stride_inter_step + - idx_feats * stride_inter_dim) - if KERNEL_WIDTH >= 2: - tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) - if KERNEL_WIDTH >= 3: - tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) - if KERNEL_WIDTH >= 4: - tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) - - -def causal_conv1d_update_npu( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - query_start_loc: torch.Tensor | None = None, - max_query_len: int = -1, - intermediate_conv_window: Optional[torch.Tensor] = None, - pad_slot_id: int = PAD_SLOT_ID, - metadata=None, - validate_data=False, -): - """ - x: (batch, dim) or (batch, dim, seqlen) - [shape=2: single token prediction] - [shape=3: single or multiple tokens prediction] - conv_state: (..., dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If not None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: (batch, dim) or (batch, dim, seqlen) - """ - if query_start_loc is not None: - return original_causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - activation=activation, - conv_state_indices=conv_state_indices, - num_accepted_tokens=num_accepted_tokens, - query_start_loc=query_start_loc, - max_query_len=max_query_len, - validate_data=validate_data) - - if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM - assert pad_slot_id is not None - assert x.stride(1) == 1 - if isinstance(activation, bool): - activation = "silu" if activation is True else None - elif activation is not None: - assert activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - # make it (batch, dim, seqlen) with seqlen == 1 - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() - - if validate_data: - assert dim == weight.size(0) - assert ( - conv_state.stride(-2) == 1 - ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch, ) == conv_state_indices.shape - - assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer - - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' - out = x - stride_w_dim, stride_w_width = weight.stride() - - stride_x_seq, stride_x_dim, stride_x_token = x.stride( - ) # X (batch, dim, seqlen) - - stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( - ) - stride_state_indices = (conv_state_indices.stride(0) - if conv_state_indices is not None else 0) - state_len = width - 1 + (seqlen - 1) # effective state_len needed - np2_statelen = triton.next_power_of_2(state_len) - - def grid(META): - return ( - batch, - triton.cdiv(dim, META["BLOCK_N"]), - ) - - # prepare intermediate buffer strides if provided - if intermediate_conv_window is not None: - stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( - intermediate_conv_window.stride(0), - intermediate_conv_window.stride(1), - intermediate_conv_window.stride(2), - intermediate_conv_window.stride(3), - ) - else: - stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 - - _causal_conv1d_update_kernel[grid]( - # Pointers to matrices - x, - weight, - bias, - conv_state, - cache_seqlens, - conv_state_indices, - num_accepted_tokens, - intermediate_conv_window - if intermediate_conv_window is not None else x, - out, - # Matrix dimensions - batch, - dim, - seqlen, - state_len, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_state_indices, - stride_inter_seq, - stride_inter_step, - stride_inter_dim, - stride_inter_win, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others - pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, - NP2_STATELEN=np2_statelen, - USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=128, - SAVE_INTERMEDIATE=intermediate_conv_window is not None, - ) - if unsqueeze: - out = out.squeeze(-1) - return out