diff --git a/tests/e2e/nightly/ops/triton/test_causal_conv1d.py b/tests/e2e/nightly/ops/triton/test_causal_conv1d.py index fe65364f..fe4eac0f 100644 --- a/tests/e2e/nightly/ops/triton/test_causal_conv1d.py +++ b/tests/e2e/nightly/ops/triton/test_causal_conv1d.py @@ -6,6 +6,8 @@ import torch.nn.functional as F from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID, causal_conv1d_fn) +from vllm_ascend.ops.triton.mamba.causal_conv1d import \ + causal_conv1d_update_npu as causal_conv1d_update def validate_cmp(y_cal, y_ref, dtype, device='npu'): @@ -156,17 +158,13 @@ def causal_conv1d_fn_pytorch( @pytest.mark.parametrize('has_initial_state', [False, True]) -@pytest.mark.parametrize('itype', - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('silu_activation', [False, True]) -@pytest.mark.parametrize('has_bias', [False, True]) -@pytest.mark.parametrize('seq_len', [[ - 1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, - 2048, 4096 -]]) +@pytest.mark.parametrize('itype', [torch.bfloat16]) +@pytest.mark.parametrize('silu_activation', [True]) +@pytest.mark.parametrize('has_bias', [True]) +@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]]) @pytest.mark.parametrize('extra_state_len', [0, 2]) -@pytest.mark.parametrize('width', [2, 3, 4]) -@pytest.mark.parametrize('dim', [64, 4160]) +@pytest.mark.parametrize('width', [2, 4]) +@pytest.mark.parametrize('dim', [4160]) def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias, silu_activation, itype, has_initial_state): @@ -227,4 +225,137 @@ def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias, query_start_loc=query_start_loc) validate_cmp(out, out_ref, itype) - validate_cmp(conv_states, conv_states_ref, itype) \ No newline at end of file + validate_cmp(conv_states, conv_states_ref, itype) + + +def causal_conv1d_update_ref(x, + conv_state, + weight, + bias=None, + activation=None, + cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the + conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to( + weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange( + -(width - 1), 0, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = (torch.remainder(width_idx, state_len).unsqueeze(1).expand( + -1, dim, -1)) + x_new = torch.cat([conv_state.gather(2, width_idx), x], + dim=-1).to(weight.dtype) + copy_idx = torch.arange( + seqlen, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, + state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, + groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("seqlen", [1, 3]) +@pytest.mark.parametrize("width", [3, 4]) +@pytest.mark.parametrize("dim", [2048 + 16, 4096]) +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("batch_size", [3, 64]) +def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, + width, seqlen, has_bias, + silu_activation, itype): + device = "npu" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + # total_entries = number of cache line + total_entries = 10 * batch_size + + # x will be (batch, dim, seqlen) with contiguous along dim-axis + x = torch.randn(padded_batch_size, seqlen, dim, device=device, + dtype=itype).transpose(1, 2) + + x_ref = x.clone() + + conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[conv_state_indices] = False + padded_state_indices = torch.concat( + [ + conv_state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) + + # conv_state will be (cache_lines, dim, state_len) + # with contiguous along dim-axis + conv_state = torch.randn(total_entries, + width - 1, + dim, + device=device, + dtype=itype).transpose(1, 2) + + conv_state_for_padding_test = conv_state.clone() + + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + conv_state_ref = conv_state[conv_state_indices, :].detach().clone() + activation = None if not silu_activation else "silu" + + out = causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = causal_conv1d_update_ref(x_ref[:batch_size], + conv_state_ref, + weight, + bias, + activation=activation) + + assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) + assert torch.equal(conv_state[unused_states_bool], + conv_state_for_padding_test[unused_states_bool]) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index fafa00c2..38b838d8 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -441,317 +441,407 @@ def causal_conv1d_fn(x: torch.Tensor, return out -# TODO copied from vllm and it needs to be optimized -@triton.jit() -def _original_causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - conv_state_indices_ptr, - num_accepted_tokens_ptr, - query_start_loc_ptr, # (batch + 1) - block_idx_last_scheduled_token, # (batch,) - initial_state_idx, # (batch,) - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_APC_ENABLED: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, +@triton.jit +def _causal_conv1d_update_kernel_npu_tiled( + # Pointers + x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, # (num_cache_lines, dim, state_len) + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) + o_ptr, # same shape as x_ptr + batch: tl.int32, + dim: tl.constexpr, + seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen + state_len: tl.constexpr, # effective state_len computed in wrapper + num_cache_lines: tl.constexpr, + + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + + # others + pad_slot_id: tl.constexpr, + + # Meta + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, # <= 6 + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + + # tiling + BLOCK_N: tl.constexpr, # channel tile (C_TILE) + B_TILE: tl.constexpr, # batch tile + T_CHUNK: tl.constexpr, # token chunk for state update ): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: - return + # program ids + pid_b = tl.program_id(0) # batch-tile id + pid_c = tl.program_id(1) # channel-tile id - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_APC_ENABLED: - # Get the state from the initial_state_idx - conv_state_init = tl.load(initial_state_idx + idx_seq) - current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) - else: - conv_state_init = 0 - current_last_index = 0 - - # cache_idx - conv_states_input_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - conv_state_init).to(tl.int64) - - if USE_PAD_SLOT: # noqa - if conv_states_input_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_VARLEN: - query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) - query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to( - tl.int64) - # revise state_len and seqlen - state_len = state_len - (seqlen - - (query_end_index - query_start_index)) - seqlen = query_end_index - query_start_index - x_offset = query_start_index * stride_x_token - o_offset = query_start_index * stride_o_token - else: - query_start_index = idx_seq * seqlen - query_end_index = query_start_index + seqlen - x_offset = idx_seq * stride_x_seq - o_offset = idx_seq * stride_o_seq - - if query_start_index == query_end_index: - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = ( - tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1) - else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_states_input_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) + # channel indices for this program + idx_feats = pid_c * BLOCK_N + tl.arange(0, BLOCK_N) # [BLOCK_N] mask_w = idx_feats < dim - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + # preload weights once per program (shared by B_TILE sequences) + w_base = w_ptr + idx_feats * stride_w_dim + # define to avoid "undefined" in branches + w_col0 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col1 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col2 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col3 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col4 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col5 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + if KERNEL_WIDTH >= 1: + w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) if KERNEL_WIDTH >= 6: - conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] - col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, + other=0.0).to(tl.float32) - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # With speculative decoding, the conv_state updates works in a sliding - # window manner, at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. - conv_state_ptrs_source = ( - conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * - stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] - mask = ((conv_states_input_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] - - x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens - VAL >= 0)[:, None] - & (idx_tokens - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - # Get the state from the initial_state_idx - # cache_idx - conv_states_offset = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices + - current_last_index).to(tl.int64) - conv_state_ptrs_target = ( - conv_state_ptr + - (conv_states_offset * stride_conv_state_seq) # Offset from seq - + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] - idx_tokens * stride_conv_state_tok)[:, None] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator + # bias vector once per program if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, + other=0.0).to(tl.float32) else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_bias = tl.zeros((BLOCK_N, ), dtype=tl.float32) - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 5: - w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor - w_col4 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 6: - w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor - w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + # token index vector for chunked copy + tok_vec = tl.arange(0, T_CHUNK) # [T_CHUNK] - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim + # process B_TILE sequences inside the same program instance + for bi in tl.static_range(0, B_TILE): + b = pid_b * B_TILE + bi # scalar tl.int32 + lane_active = b < batch # scalar predicate - # STEP 5: compute each token - for idx_token in tl.range(seqlen): - acc = acc_preload + # ------------------------- + # APC mapping (optional) + # ------------------------- + if IS_APC_ENABLED: + conv_state_init = tl.load(initial_state_idx + b, + mask=lane_active, + other=0).to(tl.int32) + current_last_index = tl.load(block_idx_last_scheduled_token + b, + mask=lane_active, + other=0).to(tl.int32) + else: + conv_state_init = tl.full((), 0, tl.int32) + current_last_index = tl.full((), 0, tl.int32) - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): + # input cache line + conv_states_input_coord = tl.load(conv_state_indices_ptr + + b * stride_state_indices + + conv_state_init, + mask=lane_active, + other=0).to(tl.int64) + + if USE_PAD_SLOT: + lane_active = lane_active & (conv_states_input_coord + != pad_slot_id) + + # ------------------------- + # varlen (optional): revise seqlen_run and state_len_run like original kernel does + # ------------------------- + if IS_VARLEN: + qs = tl.load(query_start_loc_ptr + b, mask=lane_active, + other=0).to(tl.int64) + qe = tl.load(query_start_loc_ptr + (b + 1), + mask=lane_active, + other=0).to(tl.int64) + seqlen_run = (qe - qs).to(tl.int32) + # revise effective state_len for shorter sequences (same formula as original) + state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32) + x_offset = (qs * stride_x_token).to(tl.int64) + o_offset = (qs * stride_o_token).to(tl.int64) + else: + seqlen_run = tl.full((), seqlen, tl.int32) + state_len_run = tl.full((), state_len, tl.int32) + x_offset = (b * stride_x_seq).to(tl.int64) + o_offset = (b * stride_o_seq).to(tl.int64) + + # empty sequence -> skip (avoid early return because other lanes in tile) + lane_active = lane_active & (seqlen_run > 0) + + # ------------------------- + # spec decoding offset (optional) + # ------------------------- + if IS_SPEC_DECODING: + conv_state_token_offset = ( + tl.load(num_accepted_tokens_ptr + b, mask=lane_active, + other=1).to(tl.int64) - 1) + shift = tl.full((), 1, tl.int32) # sliding by 1 in spec mode + else: + conv_state_token_offset = tl.full((), 0, tl.int64) + shift = seqlen_run # normal mode shift by seqlen + + # ------------------------- + # STEP 1: read initial history cols BEFORE state update (out==x safe) + # ------------------------- + conv_states_base = (conv_state_ptr + + conv_states_input_coord * stride_conv_state_seq + + idx_feats * stride_conv_state_dim) + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + + # define history vectors as zeros then load conditionally + col0 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col1 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col2 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col3 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col4 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + if KERNEL_WIDTH >= 2: + col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 3: + col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 4: + col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 5: + col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + if KERNEL_WIDTH >= 6: + col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + + # ------------------------- + # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block) + # Semantics: conv_state <- concat(old_state, x)[-state_len_run:]. + # - If seqlen_run >= state_len_run: dst[:] = x[seqlen_run - state_len_run : seqlen_run] + # - Else: keep = state_len_run - seqlen_run, + # dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run] + # ------------------------- + # output cache line + conv_states_offset = tl.load(conv_state_indices_ptr + + b * stride_state_indices + + current_last_index, + mask=lane_active, + other=0).to(tl.int64) + + use_shift = (seqlen_run < state_len_run) + use_tail = (seqlen_run >= state_len_run) + + zero_i32 = tl.full((), 0, tl.int32) + keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), + zero_i32).to(tl.int32) + tail_start = tl.where(use_tail, (seqlen_run - state_len_run), + zero_i32).to(tl.int32) + + # base pointers + state_src_base = (conv_state_ptr + + conv_states_input_coord * stride_conv_state_seq + + conv_state_token_offset * stride_conv_state_tok + + idx_feats * stride_conv_state_dim) + state_dst_base = (conv_state_ptr + + conv_states_offset * stride_conv_state_seq + + idx_feats * stride_conv_state_dim) + + x_base = x_ptr + x_offset + idx_feats * stride_x_dim + + # A) shift old state into dst[0:keep_shift) (only when seqlen_run < state_len_run) + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] + src_tok = (dst_tok + shift).to(tl.int32) # [T_CHUNK] + m_tok = use_shift & (dst_tok < keep_shift) & ( + src_tok < state_len_run) & (dst_tok < state_len_run) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( + conv_states_input_coord + < num_cache_lines) & (conv_states_offset < num_cache_lines) + + src_ptrs = state_src_base[ + None, :] + src_tok[:, None] * stride_conv_state_tok + dst_ptrs = state_dst_base[ + None, :] + dst_tok[:, None] * stride_conv_state_tok + vals = tl.load(src_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, vals, mask=m) + + # B) append x into dst[keep_shift : keep_shift+seqlen_run) (only when seqlen_run < state_len_run) + for t0 in tl.static_range(0, seqlen, T_CHUNK): + x_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] + dst_tok = (keep_shift + x_tok).to(tl.int32) # [T_CHUNK] + m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok + < state_len_run) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( + conv_states_offset < num_cache_lines) + + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = state_dst_base[ + None, :] + dst_tok[:, None] * stride_conv_state_tok + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + + # C) if seqlen_run >= state_len_run, overwrite dst with the tail of x + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] + x_tok = (tail_start + dst_tok).to(tl.int32) # [T_CHUNK] + m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( + conv_states_offset < num_cache_lines) + + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = state_dst_base[ + None, :] + dst_tok[:, None] * stride_conv_state_tok + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + + # ------------------------- + # STEP 3/4/5: causal conv1d (+ optional SiLU) and store output + # This is original STEP3~5, but per-lane and without debug_barrier. + # ------------------------- + x_base_1d = x_base + o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim + + # accumulator preload (bias) + acc_preload = acc_bias + + # compute each token; keep tl.range so varlen can use seqlen_run as runtime trip count (like original) + for idx_token in tl.range(seqlen_run): + acc = acc_preload + + # same selection logic as original (unrolled by KERNEL_WIDTH) + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 1: + # only x[t] * w0 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + matrix_w = w_col0 + elif KERNEL_WIDTH == 2: + if j == 1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, + mask=lane_active & mask_w, + other=0.0).to(tl.float16) + + acc += matrix_x.to(tl.float32) * matrix_w # [BLOCK_N] + + # roll history window if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = matrix_x elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = matrix_x elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = col2 + col2 = matrix_x elif KERNEL_WIDTH == 5: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - matrix_x = col3 - elif j == 4: - matrix_w = w_col4 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x elif KERNEL_WIDTH == 6: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - matrix_x = col3 - elif j == 4: - matrix_w = w_col4 - matrix_x = col4 - elif j == 5: - matrix_w = w_col5 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x - acc += matrix_x * matrix_w # [BLOCK_N] + if SILU_ACTIVATION: + acc = acc / (1.0 + tl.exp(-acc)) - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - elif KERNEL_WIDTH == 5: - col0 = col1 - col1 = col2 - col2 = col3 - col3 = matrix_x - elif KERNEL_WIDTH == 6: - col0 = col1 - col1 = col2 - col2 = col3 - col3 = col4 - col4 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & (idx_feats < dim - ) # token-index # feature-index - o_ptrs = (o_ptr + o_offset + idx_token * stride_o_token + - (idx_feats * stride_o_dim)) - - tl.store(o_ptrs, acc, mask=mask_1d) + # store output + o_ptrs = o_base_1d + idx_token * stride_o_token + tl.store(o_ptrs, acc, mask=lane_active & mask_w) -# TODO copied from vllm and it needs to be optimized -def original_causal_conv1d_update( +def causal_conv1d_update_npu( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, @@ -818,6 +908,7 @@ def original_causal_conv1d_update( if unsqueeze: # make it (batch, dim, seqlen) with seqlen == 1 x = x.unsqueeze(-1) + if query_start_loc is None: batch, dim, seqlen = x.shape else: @@ -825,36 +916,25 @@ def original_causal_conv1d_update( batch = conv_state_indices.size(0) dim = x.size(1) seqlen = max_query_len + _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() + num_cache_lines, _, state_len_total = conv_state.size() if validate_data: assert dim == weight.size(0) - assert conv_state.stride(-2) == 1, ( - f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - ) - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch, ) == conv_state_indices.shape - + assert conv_state.stride(-2) == 1 + assert state_len_total >= width - 1 assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this + assert weight.stride(1) == 1 - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + # overwrite-on-x strategy same as original out = x - stride_w_dim, stride_w_width = weight.stride() + stride_w_dim, stride_w_width = weight.stride() if query_start_loc is None: - # X (batch, dim, seqlen) stride_x_seq, stride_x_dim, stride_x_token = x.stride() stride_o_seq, stride_o_dim, stride_o_token = out.stride() else: - # X (dim, cu_seqlen) stride_x_token, stride_x_dim = x.stride() stride_x_seq = 0 stride_o_token, stride_o_dim = out.stride() @@ -862,22 +942,46 @@ def original_causal_conv1d_update( stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( ) - stride_state_indices = (conv_state_indices.stride(0) - if conv_state_indices is not None else 0) + stride_state_indices = conv_state_indices.stride( + 0) if conv_state_indices is not None else 0 + + # effective state_len exactly as original if num_accepted_tokens is not None: - state_len = width - 1 + (seqlen - 1) # effective state_len needed + eff_state_len = width - 1 + (seqlen - 1) else: - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) + eff_state_len = width - 1 + np2_statelen = triton.next_power_of_2(eff_state_len) + + # -------- tiling heuristic-------- + #keep program count around ~[80..160] + # vector core 40 + # TODO: use driver to get the vector core num + CORE_HINT = 40 + # channel tile: 512 when dim large (reduce tasks), else 256 + block_n = 512 if dim >= 512 else 256 + g = triton.cdiv(dim, block_n) + target = 2 * CORE_HINT # ~80 + b_tile_raw = max(1, (batch * g + target - 1) // target) + # clamp to small set + if b_tile_raw <= 1: + b_tile = 1 + elif b_tile_raw <= 2: + b_tile = 2 + elif b_tile_raw <= 4: + b_tile = 4 + else: + b_tile = 8 + + # token chunk based on block_n (32KB UB idea); conservative + t_chunk = 20 if block_n == 512 else 48 def grid(META): return ( - batch, + triton.cdiv(batch, META["B_TILE"]), triton.cdiv(dim, META["BLOCK_N"]), ) - _original_causal_conv1d_update_kernel[grid]( - # Pointers to matrices + _causal_conv1d_update_kernel_npu_tiled[grid]( x, weight, bias, @@ -888,13 +992,11 @@ def original_causal_conv1d_update( block_idx_last_scheduled_token, initial_state_idx, out, - # Matrix dimensions batch, dim, seqlen, - state_len, + eff_state_len, num_cache_lines, - # stride stride_x_seq, stride_x_dim, stride_x_token, @@ -907,9 +1009,7 @@ def original_causal_conv1d_update( stride_o_seq, stride_o_dim, stride_o_token, - # others pad_slot_id, - # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], @@ -918,431 +1018,11 @@ def original_causal_conv1d_update( IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=256, + BLOCK_N=block_n, + B_TILE=b_tile, + T_CHUNK=t_chunk, ) + if unsqueeze: out = out.squeeze(-1) return out.to(original_x_dtype) - - -@triton.jit() -def _causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - cache_seqlens_ptr, # circular buffer - conv_state_indices_ptr, - num_accepted_tokens_ptr, - intermediate_conv_window_ptr, - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_inter_seq: tl.constexpr, - stride_inter_step: tl.constexpr, - stride_inter_dim: tl.constexpr, - stride_inter_win: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, - SAVE_INTERMEDIATE: tl.constexpr, -): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: - return - - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) - else: - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = tl.load(num_accepted_tokens_ptr + - idx_seq) - 1 - else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) - mask_w = idx_feats < dim - - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok - if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # The conv_state updates works in a sliding window manner, - # at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. - conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + 1) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] - - x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens - VAL >= 0)[:, None] - & (idx_tokens - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = (conv_state_base + - (idx_tokens * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) - - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim - - # STEP 5: compute each token - for idx_token in tl.static_range(seqlen): - acc = acc_preload - - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - # mask_1d = (idx_token < seqlen) & ( - # idx_feats < dim - # ) # token-index # feature-index - maskL = idx_feats < dim - maskR = tl.full(maskL.shape, False, tl.int1) - mask_1d = tl.where(idx_token < seqlen, maskL, maskR) - - o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + - idx_token * stride_o_token + (idx_feats * stride_o_dim)) - - tl.store(o_ptrs, acc, mask=mask_1d) - - if SAVE_INTERMEDIATE: - # Save the window state after consuming this token - # Layout: [seq(cache line), step, dim, win(K-1)] - base_ptr = (intermediate_conv_window_ptr + - conv_state_batch_coord * stride_inter_seq + - idx_token * stride_inter_step + - idx_feats * stride_inter_dim) - if KERNEL_WIDTH >= 2: - tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) - if KERNEL_WIDTH >= 3: - tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) - if KERNEL_WIDTH >= 4: - tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) - - -def causal_conv1d_update_npu( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - query_start_loc: torch.Tensor | None = None, - max_query_len: int = -1, - intermediate_conv_window: Optional[torch.Tensor] = None, - pad_slot_id: int = PAD_SLOT_ID, - metadata=None, - validate_data=False, -): - """ - x: (batch, dim) or (batch, dim, seqlen) - [shape=2: single token prediction] - [shape=3: single or multiple tokens prediction] - conv_state: (..., dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If not None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: (batch, dim) or (batch, dim, seqlen) - """ - if query_start_loc is not None: - return original_causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - activation=activation, - conv_state_indices=conv_state_indices, - num_accepted_tokens=num_accepted_tokens, - query_start_loc=query_start_loc, - max_query_len=max_query_len, - validate_data=validate_data) - - if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM - assert pad_slot_id is not None - assert x.stride(1) == 1 - if isinstance(activation, bool): - activation = "silu" if activation is True else None - elif activation is not None: - assert activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - # make it (batch, dim, seqlen) with seqlen == 1 - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() - - if validate_data: - assert dim == weight.size(0) - assert ( - conv_state.stride(-2) == 1 - ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch, ) == conv_state_indices.shape - - assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer - - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' - out = x - stride_w_dim, stride_w_width = weight.stride() - - stride_x_seq, stride_x_dim, stride_x_token = x.stride( - ) # X (batch, dim, seqlen) - - stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( - ) - stride_state_indices = (conv_state_indices.stride(0) - if conv_state_indices is not None else 0) - state_len = width - 1 + (seqlen - 1) # effective state_len needed - np2_statelen = triton.next_power_of_2(state_len) - - def grid(META): - return ( - batch, - triton.cdiv(dim, META["BLOCK_N"]), - ) - - # prepare intermediate buffer strides if provided - if intermediate_conv_window is not None: - stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( - intermediate_conv_window.stride(0), - intermediate_conv_window.stride(1), - intermediate_conv_window.stride(2), - intermediate_conv_window.stride(3), - ) - else: - stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 - - _causal_conv1d_update_kernel[grid]( - # Pointers to matrices - x, - weight, - bias, - conv_state, - cache_seqlens, - conv_state_indices, - num_accepted_tokens, - intermediate_conv_window - if intermediate_conv_window is not None else x, - out, - # Matrix dimensions - batch, - dim, - seqlen, - state_len, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_state_indices, - stride_inter_seq, - stride_inter_step, - stride_inter_dim, - stride_inter_win, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others - pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, - NP2_STATELEN=np2_statelen, - USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=128, - SAVE_INTERMEDIATE=intermediate_conv_window is not None, - ) - if unsqueeze: - out = out.squeeze(-1) - return out