diff --git a/tests/e2e/nightly/ops/triton/test_causal_conv1d.py b/tests/e2e/nightly/ops/triton/test_causal_conv1d.py
new file mode 100644
index 00000000..fe65364f
--- /dev/null
+++ b/tests/e2e/nightly/ops/triton/test_causal_conv1d.py
@@ -0,0 +1,230 @@
+from typing import Optional
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
+                                                        causal_conv1d_fn)
+
+
+def validate_cmp(y_cal, y_ref, dtype, device='npu'):
+    y_cal = y_cal.to(device)
+    y_ref = y_ref.to(device)
+    if dtype == torch.float16:
+        torch.testing.assert_close(y_ref,
+                                   y_cal,
+                                   rtol=3e-03,
+                                   atol=1e-02,
+                                   equal_nan=True)
+    elif dtype == torch.bfloat16:
+        torch.testing.assert_close(y_ref,
+                                   y_cal,
+                                   rtol=1e-02,
+                                   atol=1e-02,
+                                   equal_nan=True)
+    elif dtype == torch.float32:
+        torch.testing.assert_close(y_ref,
+                                   y_cal,
+                                   rtol=1e-03,
+                                   atol=4e-03,
+                                   equal_nan=True)
+    elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64,
+                   torch.uint32):
+        assert torch.equal(y_cal, y_ref)
+    elif dtype == torch.bool:
+        assert torch.equal(y_cal, y_ref)
+    else:
+        raise ValueError('Unsupported dtype: {}'.format(dtype))
+
+
+def causal_conv1d_ref(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    initial_states: Optional[torch.Tensor] = None,
+    return_final_states: bool = False,
+    final_states_out: Optional[torch.Tensor] = None,
+    activation: Optional[str] = "silu",
+):
+    """
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1)
+    out: (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    x = x.to(weight.dtype)
+    seqlen = x.shape[-1]
+    dim, width = weight.shape
+
+    if initial_states is None:
+        out = F.conv1d(x,
+                       weight.unsqueeze(1),
+                       bias,
+                       padding=width - 1,
+                       groups=dim)
+    else:
+        x = torch.cat([initial_states, x], dim=-1)
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+    out = out[..., :seqlen]
+
+    if return_final_states:
+        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+            dtype_in)  # (batch, dim, width - 1)
+        if final_states_out is not None:
+            final_states_out.copy_(final_states)
+        else:
+            final_states_out = final_states
+    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+    return (out, None) if not return_final_states else (out, final_states_out)
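+
+
+# Explanatory note (added in review, not part of the test logic): per channel,
+# causal_conv1d_ref computes, for a kernel of the given width,
+#     out[t] = silu(bias + sum_k weight[:, k] * x[:, t - (width - 1) + k])
+# where positions before t = 0 read as zeros or, if provided, come from
+# initial_states.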
+
+
+def causal_conv1d_fn_pytorch(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    cache_indices: torch.Tensor,
+    has_initial_state: torch.Tensor,
+    conv_states: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    activation: Optional[str] = "silu",
+    pad_slot_id: int = PAD_SLOT_ID,
+):
+    """
+    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
+        sequences are concatenated from left to right for varlen
+    weight: (dim, width)
+    bias: (dim,)
+    query_start_loc: (batch + 1) int32
+        The cumulative sequence lengths of the sequences in
+        the batch, used to index into sequence. prepended by 0.
+        for example: query_start_loc = torch.Tensor([0,10,16,17]),
+        x.shape=(dim,17)
+    cache_indices: (batch) int32
+        indicates the corresponding state index,
+        like so: conv_state = conv_states[cache_indices[batch_id]]
+    has_initial_state: (batch) bool
+        indicates whether the kernel should take the current state as
+        initial state for the calculations
+    conv_states: (...,dim,width - 1) itype
+        updated inplace if provided
+    activation: either None or "silu" or "swish"
+    pad_slot_id: int
+        if cache_indices is passed, lets the kernel identify padded
+        entries that will not be processed,
+        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+        in this case, the kernel will not process entries at
+        indices 0 and 3
+    out: (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    bias = bias.contiguous() if bias is not None else None
+
+    out_ref = []
+    out_ref_b = []
+    seqlens = query_start_loc[1:] - query_start_loc[:-1]
+    seqlens = seqlens.tolist()
+    splits = torch.split(x, seqlens, dim=-1)
+    width = weight.shape[1]
+
+    for i in range(len(seqlens)):
+        x_s = splits[i]
+        if cache_indices[i] == PAD_SLOT_ID:
+            continue
+        out_ref_b.append(
+            causal_conv1d_ref(
+                x_s,
+                weight,
+                bias,
+                activation=activation,
+                return_final_states=True,
+                final_states_out=conv_states[cache_indices[i]][..., :(
+                    width - 1)].unsqueeze(0),
+                initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
+                if has_initial_state[i] else None))
+    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
+    out_ref_tensor = torch.cat(out_ref, dim=0)
+    return out_ref_tensor
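+
+
+# Worked example (illustrative): with query_start_loc = [0, 10, 16, 17] the
+# per-sequence lengths are the consecutive differences [10, 6, 1], and
+# sequence i occupies x[:, query_start_loc[i]:query_start_loc[i + 1]] of the
+# concatenated (dim, 17) input.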
+
+
+@pytest.mark.parametrize('has_initial_state', [False, True])
+@pytest.mark.parametrize('itype',
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('silu_activation', [False, True])
+@pytest.mark.parametrize('has_bias', [False, True])
+@pytest.mark.parametrize('seq_len', [[
+    1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134,
+    2048, 4096
+]])
+@pytest.mark.parametrize('extra_state_len', [0, 2])
+@pytest.mark.parametrize('width', [2, 3, 4])
+@pytest.mark.parametrize('dim', [64, 4160])
+def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
+                       silu_activation, itype, has_initial_state):
+
+    torch.random.manual_seed(0)
+
+    device = "npu"
+    cu_seqlen, num_seq = sum(seq_len), len(seq_len)
+    state_len = width - 1 + extra_state_len
+
+    x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
+    weight = torch.randn(dim, width, device=device, dtype=itype)
+    query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
+                                                device=device,
+                                                dtype=torch.int32),
+                                   dim=0)
+    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
+    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
+                                            device=device,
+                                            dtype=torch.bool)
+    activation = None if not silu_activation else "silu"
+
+    if has_initial_state:
+        conv_states = torch.randn((num_seq, state_len, dim),
+                                  device=device,
+                                  dtype=itype).transpose(-1, -2)
+        conv_states_ref = torch.randn(
+            (num_seq, state_len, dim), device=device,
+            dtype=itype).transpose(-1, -2).copy_(conv_states)
+    else:
+        conv_states = torch.zeros((num_seq, state_len, dim),
+                                  device=device,
+                                  dtype=itype).transpose(-1, -2)
+        conv_states_ref = torch.zeros((num_seq, state_len, dim),
+                                      device=device,
+                                      dtype=itype).transpose(-1, -2)
+
+    if has_bias:
+        bias = torch.randn(dim, device=device, dtype=itype)
+    else:
+        bias = None
+
+    out_ref = causal_conv1d_fn_pytorch(
+        x,
+        weight,
+        bias=bias,
+        activation=activation,
+        conv_states=conv_states_ref,
+        has_initial_state=has_initial_state_tensor,
+        cache_indices=cache_indices,
+        query_start_loc=query_start_loc)
+    out = causal_conv1d_fn(x,
+                           weight,
+                           bias=bias,
+                           activation=activation,
+                           conv_states=conv_states,
+                           has_initial_state=has_initial_state_tensor,
+                           cache_indices=cache_indices,
+                           query_start_loc=query_start_loc)
+
+    validate_cmp(out, out_ref, itype)
+    validate_cmp(conv_states, conv_states_ref, itype)
\ No newline at end of file
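A minimal usage sketch distilled from the test above (illustrative only; it
assumes an NPU device and the module path introduced by this PR, and is not
part of the diff):

import torch

from vllm_ascend.ops.triton.mamba.causal_conv1d import causal_conv1d_fn

dim, width, seq_len = 64, 4, [10, 6, 1]
cu_seqlen, num_seq = sum(seq_len), len(seq_len)

# Channel-last layout: x.stride(0) == 1, as the kernel's validate_data expects.
x = torch.randn(cu_seqlen, dim, device="npu").transpose(0, 1)
weight = torch.randn(dim, width, device="npu")
conv_states = torch.zeros((num_seq, width - 1, dim),
                          device="npu").transpose(-1, -2)
query_start_loc = torch.cumsum(
    torch.tensor([0] + seq_len, device="npu", dtype=torch.int32), dim=0)
cache_indices = torch.arange(num_seq, device="npu", dtype=torch.int32)
has_initial_state = torch.zeros(num_seq, device="npu", dtype=torch.bool)

out = causal_conv1d_fn(x,
                       weight,
                       bias=None,
                       conv_states=conv_states,
                       query_start_loc=query_start_loc,
                       cache_indices=cache_indices,
                       has_initial_state=has_initial_state,
                       activation="silu")  # out has the same shape as x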
diff --git a/vllm_ascend/ops/triton/mamba/casual_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
similarity index 67%
rename from vllm_ascend/ops/triton/mamba/casual_conv1d.py
rename to vllm_ascend/ops/triton/mamba/causal_conv1d.py
index 79da996b..fafa00c2 100644
--- a/vllm_ascend/ops/triton/mamba/casual_conv1d.py
+++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
@@ -1,4 +1,4 @@
-# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py
+# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py

 # SPDX-License-Identifier: Apache-2.0
@@ -10,78 +10,289 @@
 from typing import Any, Optional, Union

 import torch
-import torch.nn.functional as F
 import triton
 import triton.language as tl

 PAD_SLOT_ID = -1


-def causal_conv1d_ref(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    initial_states: Optional[torch.Tensor] = None,
-    return_final_states: bool = False,
-    final_states_out: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
+@triton.jit()
+def _causal_conv1d_fwd_kernel(  # continuous batching
+    # Pointers to matrices
+    x_ptr,  # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
+    w_ptr,  # (dim, width)
+    bias_ptr,
+    conv_states_ptr,
+    conv_state_indices_ptr,
+    has_initial_states_ptr,
+    query_start_loc_ptr,
+    batch_ptr,
+    token_chunk_offset_ptr,
+    o_ptr,  # (dim, seqlen)
+    # Matrix dimensions
+    dim: tl.constexpr,
+    state_len: int,
+    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
+    # Strides
+    stride_x_dim: tl.constexpr,  # stride to get to next feature-value
+    stride_x_token: tl.constexpr,  # stride to get to next token
+    stride_w_dim: tl.constexpr,  # stride to get to next dim-axis value
+    stride_w_width: tl.constexpr,  # stride to get to next width-axis value
+    stride_conv_state_seq: tl.constexpr,
+    stride_conv_state_dim: tl.constexpr,
+    stride_conv_state_tok: tl.constexpr,
+    stride_cache_indices: tl.constexpr,
+    stride_o_dim: tl.constexpr,
+    stride_o_token: tl.constexpr,
+    # others
+    pad_slot_id: tl.constexpr,
+    # Meta-parameters
+    HAS_BIAS: tl.constexpr,
+    KERNEL_WIDTH: tl.constexpr,
+    SILU_ACTIVATION: tl.constexpr,
+    HAS_INITIAL_STATES: tl.constexpr,
+    IS_CONTINUOUS_BATCHING: tl.constexpr,
+    USE_PAD_SLOT: tl.constexpr,
+    NP2_STATELEN: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
 ):
-    """
-    x: (batch, dim, seqlen)
-    weight: (dim, width)
-    bias: (dim,)
-    initial_states: (batch, dim, width - 1)
-    final_states_out: (batch, dim, width - 1)
-    out: (batch, dim, seqlen)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    dtype_in = x.dtype
-    x = x.to(weight.dtype)
-    seqlen = x.shape[-1]
-    dim, width = weight.shape
+    # single-sequence id
+    idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
+    chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))

-    if initial_states is None:
-        out = F.conv1d(x,
-                       weight.unsqueeze(1),
-                       bias,
-                       padding=width - 1,
-                       groups=dim)
+    # BLOCK_N elements along the feature-dimension (channel)
+    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
+    sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
+    # find the actual sequence length
+    seqlen = sequence_end_index - sequence_start_index
+
+    token_offset = BLOCK_M * chunk_offset
+    segment_len = min(BLOCK_M, seqlen - token_offset)
+
+    # base of the sequence
+    x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim  # [BLOCK_N,]
+
+    if IS_CONTINUOUS_BATCHING:
+        # cache_idx
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
+                                         idx_seq * stride_cache_indices).to(
+                                             tl.int64)
     else:
-        x = torch.cat([initial_states, x], dim=-1)
-        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
-    out = out[..., :seqlen]
+        # cache_idx
+        conv_state_batch_coord = idx_seq
+
-    if return_final_states:
-        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
-            dtype_in)  # (batch, dim, width - 1)
-        if final_states_out is not None:
-            final_states_out[..., :(width - 1)].copy_(final_states)
+    if USE_PAD_SLOT:  # noqa
+        if conv_state_batch_coord == pad_slot_id:
+            # not processing as this is not the actual sequence
+            return
+    conv_states_base = conv_states_ptr + (
+        conv_state_batch_coord * stride_conv_state_seq) + (
+            idx_feats * stride_conv_state_dim)  # [BLOCK_N,]
+    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
+
+    load_init_state = False
+    if HAS_INITIAL_STATES:
+        # whether this sequence should start from a cached initial state
+        load_init_state = tl.load(has_initial_states_ptr + idx_seq)
+
+    mask_dim = idx_feats < dim
+
+    # read prior-token data from `x`
+    offset_x = token_offset - KERNEL_WIDTH + 1
+    if KERNEL_WIDTH >= 2:
+        x0_ptrs = x_base + offset_x * stride_x_token
+        x0 = tl.load(x0_ptrs, mask_dim & (offset_x > 0), other=0.0)
+    if KERNEL_WIDTH >= 3:
+        x1_ptrs = x0_ptrs + 1 * stride_x_token
+        x1 = tl.load(x1_ptrs, mask_dim & (offset_x + 1 > 0), other=0.0)
+    if KERNEL_WIDTH >= 4:
+        x2_ptrs = x1_ptrs + 1 * stride_x_token
+        x2 = tl.load(x2_ptrs, mask_dim & (offset_x + 2 > 0), other=0.0)
+
+    if load_init_state & (chunk_offset == 0):
+        # load from conv_states
+        offset_conv_state = state_len - KERNEL_WIDTH + 1
+        if KERNEL_WIDTH >= 2:
+            x0_ptrs = conv_states_base + offset_conv_state * stride_conv_state_tok
+            x0 = tl.load(x0_ptrs, mask_dim, other=0.0)
+        if KERNEL_WIDTH >= 3:
+            x1_ptrs = x0_ptrs + 1 * stride_conv_state_tok
+            x1 = tl.load(x1_ptrs, mask_dim, other=0.0)
+        if KERNEL_WIDTH >= 4:
+            x2_ptrs = x1_ptrs + 1 * stride_conv_state_tok
+            x2 = tl.load(x2_ptrs, mask_dim, other=0.0)
+
+    if HAS_BIAS:
+        bias = bias_ptr + idx_feats
+        mask_bias = idx_feats < dim
+        acc_preload = tl.load(bias, mask=mask_bias,
+                              other=0.0).to(tl.float32)  # [BLOCK_N]
+    else:
+        acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
+
+    x_base_1d = x_base + token_offset * stride_x_token  # start of this chunk
+
+    # PRE-LOAD WEIGHTS
+    mask_dim = idx_feats < dim
+    if KERNEL_WIDTH >= 2:
+        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
+        w0 = tl.load(w_ptrs, mask_dim, other=0.0)
+        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
+        w1 = tl.load(w_ptrs, mask_dim, other=0.0)
+    if KERNEL_WIDTH >= 3:
+        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
+        w2 = tl.load(w_ptrs, mask_dim, other=0.0)
+    if KERNEL_WIDTH >= 4:
+        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
+        w3 = tl.load(w_ptrs, mask_dim, other=0.0)
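+
+    # Rolling-register scheme (explanatory note): for KERNEL_WIDTH == 3 the
+    # loop below keeps the two most recent tokens in registers and computes
+    #     acc = bias + w0 * x[t - 2] + w1 * x[t - 1] + w2 * x[t]
+    # per channel, then shifts (x0, x1) <- (x1, x) before the next token.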
+
+    for idx_token in tl.static_range(BLOCK_M):
+        acc = acc_preload
+        mask_1d = (idx_token
+                   < segment_len) & mask_dim  # token-index  # feature-index
+        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+        x = tl.load(x_ptrs_1d, mask=mask_1d)
+
+        if KERNEL_WIDTH == 2:
+            acc += x0 * w0 + x * w1
+            x0 = x
+        elif KERNEL_WIDTH == 3:
+            acc += x0 * w0 + x1 * w1 + x * w2
+            x0 = x1
+            x1 = x
+        elif KERNEL_WIDTH == 4:
+            acc += x0 * w0 + x1 * w1 + x2 * w2 + x * w3
+            x0 = x1
+            x1 = x2
+            x2 = x
+
+        if SILU_ACTIVATION:
+            acc = acc / (1 + tl.exp(-acc))
+
+        o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
+                          ) * stride_o_token + (idx_feats * stride_o_dim)
+        tl.store(o_ptrs, acc, mask=mask_1d)
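+
+    # Cache-update overview (explanatory note): after this chunk's outputs
+    # are written, conv_states[cache_idx] must hold the last `state_len`
+    # tokens of the sequence. Three cases follow: (1) seqlen >= state_len:
+    # copy the tail of 'x' directly; (2) shorter sequence with an initial
+    # state: shift the old state left and append 'x'; (3) shorter sequence
+    # without an initial state: zero-pad on the left and append 'x'.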
+
+    # update conv_state with new data [only the Triton program that handles
+    # chunk_offset == 0 updates the cache]
+    if chunk_offset == 0:
+        if state_len <= seqlen:  # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
+            # just read from 'x'
+            # copy 'x' data to conv_state
+            # load only 'x' data (and set 0 before 'x' if seqlen < state_len)
+            idx_tokens_last = (seqlen - state_len) + tl.arange(
+                0, NP2_STATELEN)  # [BLOCK_M]
+            x_ptrs = x_ptr + (
+                (sequence_start_index + idx_tokens_last) *
+                stride_x_token)[:, None] + (
+                    idx_feats * stride_x_dim)[None, :]  # [BLOCK_M,BLOCK_N,]
+            mask_x = ((idx_tokens_last >= 0)[:, None] &
+                      (idx_tokens_last < seqlen)[:, None] &
+                      (idx_feats < dim)[None, :]
+                      )  # token-index  # token-index  # feature-index
+            new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+            conv_states_ptrs_target = conv_states_base[None, :] + (
+                idx_tokens_conv * stride_conv_state_tok)[:, None]
+
+            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
+                                                             < dim)[None, :]
+            tl.debug_barrier()
+            tl.store(conv_states_ptrs_target, new_conv_state, mask)
+        elif load_init_state:
+            # update conv_state by shifting left, i.e. take the last few
+            # cols from conv_state + cols from 'x'
+            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+            conv_states_ptrs_source = (
+                conv_states_ptr +
+                (conv_state_batch_coord * stride_conv_state_seq) +
+                (idx_feats * stride_conv_state_dim)[None, :] +
+                ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
+            )  # [BLOCK_M, BLOCK_N]
+            mask = ((conv_state_batch_coord < num_cache_lines)
+                    & ((idx_tokens_conv + seqlen) < state_len)[:, None]
+                    & (idx_feats < dim)[None, :])
+            conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
+
+            VAL = state_len - seqlen
+
+            x_ptrs = x_base[None, :] + (
+                (idx_tokens_conv - VAL) *
+                stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
+
+            mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
+                      (idx_tokens_conv - VAL < seqlen)[:, None] &
+                      (idx_feats < dim)[None, :]
+                      )  # token-index  # token-index  # feature-index
+            loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+            tl.debug_barrier()
+            new_conv_state = tl.where(
+                mask, conv_state, loaded_x
+            )  # BUG in 'tl.where' which requires a barrier before this
+            conv_states_ptrs_target = conv_states_base + (
+                idx_tokens_conv *
+                stride_conv_state_tok)[:, None]  # [BLOCK_M, BLOCK_N]
+            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
+                                                             < dim)[None, :]
+            tl.store(conv_states_ptrs_target, new_conv_state, mask)
         else:
-            final_states_out = final_states
-    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
-    return (out, None) if not return_final_states else (out, final_states_out)
+            # update conv_state by shifting left, BUT
+            # set cols prior to 'x' as zeros + cols from 'x'
+            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+            VAL = state_len - seqlen
+
+            x_ptrs = x_base[None, :] + (
+                (idx_tokens_conv - VAL) *
+                stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
+
+            mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
+                      (idx_tokens_conv - VAL < seqlen)[:, None] &
+                      (idx_feats < dim)[None, :]
+                      )  # token-index  # token-index  # feature-index
+            new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+
+            conv_states_ptrs_target = conv_states_base + (
+                idx_tokens_conv *
+                stride_conv_state_tok)[:, None]  # [BLOCK_M, BLOCK_N]
+            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
+                                                             < dim)[None, :]
+            tl.debug_barrier()
+            tl.store(conv_states_ptrs_target, new_conv_state, mask)


-def causal_conv1d_fn(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    query_start_loc: Optional[torch.Tensor] = None,
-    cache_indices: Optional[torch.Tensor] = None,
-    has_initial_state: Optional[torch.Tensor] = None,
-    conv_states: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
-    pad_slot_id: int = PAD_SLOT_ID,
-    metadata: Optional[Any] = None,
-):
-    """
-    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
+def causal_conv1d_fn(x: torch.Tensor,
+                     weight: torch.Tensor,
+                     bias: Union[torch.Tensor, None],
+                     conv_states: torch.Tensor,
+                     query_start_loc: torch.Tensor,
+                     cache_indices: Optional[torch.Tensor] = None,
+                     has_initial_state: Optional[torch.Tensor] = None,
+                     activation: Optional[str] = "silu",
+                     pad_slot_id: int = PAD_SLOT_ID,
+                     metadata: Optional[Any] = None,
+                     validate_data=False):
+    """support varlen + continuous batching when x is a 2D tensor
+    x: (dim,cu_seq_len)
+        cu_seq_len = total tokens of all seqs in that batch
         sequences are concatenated from left to right for varlen
     weight: (dim, width)
-    bias: (dim,)
+    conv_states: (...,dim,width - 1) itype
+        updated inplace if provided
+        [it uses `cache_indices` to get the index into the cache of the
+        conv_state for that sequence:
+        conv_state[cache_indices[i]] for seq-i is used as initial_state
+        when has_initial_state[i] = True, and afterwards
+        conv_state[cache_indices[i]] is shifted left and updated with
+        values from 'x']
     query_start_loc: (batch + 1) int32
         The cumulative sequence lengths of the sequences in
         the batch, used to index into sequence. prepended by 0.
+        if
+        x = [5, 1, 1, 1] <- continuous batching (batch=4)
+        then
+        query_start_loc = [0, 5, 6, 7, 8] <- each entry is the starting
+        index of a sequence; the last value is the ending index of the
+        last sequence
+        [length(query_start_loc)-1 == batch]
         for example: query_start_loc = torch.Tensor([0,10,16,17]),
         x.shape=(dim,17)
     cache_indices: (batch)  int32
@@ -90,46 +301,144 @@ def causal_conv1d_fn(
     has_initial_state: (batch) bool
         indicates whether should the kernel take the current state as initial
        state for the calculations
-    conv_states: (...,dim,width - 1) itype
-        updated inplace if provided
-    activation: either None or "silu" or "swish"
+        [a single boolean for each sequence in the batch: True or False]
+    bias: (dim,)
+    activation: either None or "silu" or "swish" or True (treated as "silu")
     pad_slot_id: int
-        if cache_indices is passed, lets the kernel identify padded
-        entries that will not be processed,
-        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
-        in this case, the kernel will not process entries at
-        indices 0 and 3
-    out: (batch, dim, seqlen)
+        if cache_indices is passed, lets the kernel identify padded
+        entries that will not be processed,
+        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+        in this case, the kernel will not process entries at
+        indices 0 and 3
+    out: same shape as `x`
     """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    if x.stride(-1) != 1:
-        x = x.contiguous()
-    bias = bias.contiguous() if bias is not None else None
+    if isinstance(activation, bool) and activation:
+        activation = "silu"

-    out_ref = []
-    out_ref_b = []
-    seqlens = query_start_loc[1:] - query_start_loc[:-1]
-    seqlens = seqlens.tolist()
-    splits = torch.split(x, seqlens, dim=-1)
+    # Output buffer with the same (channel-last) strides as `x`
+    out = torch.empty_strided(x.size(),
+                              x.stride(),
+                              dtype=x.dtype,
+                              device=x.device)

-    for i in range(len(seqlens)):
-        x_s = splits[i]
-        if cache_indices[i] == PAD_SLOT_ID:
-            continue
-        out_ref_b.append(
-            causal_conv1d_ref(
-                x_s,
-                weight,
-                bias,
-                activation=activation,
-                return_final_states=True,
-                final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
-                initial_states=conv_states[cache_indices[i]]
-                if has_initial_state[i] else None))
-    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
-    out_ref_tensor = torch.cat(out_ref, dim=0)
-    return out_ref_tensor
+    dim, _ = x.shape
+    _, width = weight.shape
+
+    state_len = width - 1
+    np2_statelen = triton.next_power_of_2(state_len)
+
+    padded_batch = query_start_loc.size(0) - 1
+    stride_x_dim = x.stride(0)
+    stride_x_token = x.stride(1)
+    stride_w_dim = weight.stride(0)
+    stride_w_width = weight.stride(1)
+    stride_istate_seq = 0
+    stride_istate_dim = 0
+    stride_istate_token = 0
+    stride_o_dim = out.stride(0)
+    stride_o_token = out.stride(1)
+
+    num_cache_lines = 0
+    if conv_states is not None:
+        # extensions to support vLLM:
+        # 1. conv_states is used to replace initial_states
+        # 2. conv_states serves as a cache whose number of cache lines can
+        #    be larger than the batch size
+        # 3. sequence x[idx] maps to the cache line at the index specified
+        #    via cache_indices[idx]
+        # 4. computation can be skipped if cache_indices[idx] == pad_slot_id
+        num_cache_lines = conv_states.size(0)
+        stride_istate_seq = conv_states.stride(0)
+        stride_istate_dim = conv_states.stride(1)
+        stride_istate_token = conv_states.stride(2)
+
+    stride_cache_indices = cache_indices.stride(
+        0) if cache_indices is not None else 0
+
+    if validate_data:
+        is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
+        assert x.dim() == 2
+        assert width in [2, 3, 4]
+        assert query_start_loc is not None
+        assert query_start_loc.dim() == 1
+        assert x.stride(0) == 1 or x.stride(1) == 1
+        if bias is not None:
+            assert bias.dim() == 1
+            assert dim == bias.size(0)
+        if conv_states is not None:
+            assert (num_cache_lines == conv_states.shape[0]
+                    and dim == conv_states.shape[1]
+                    and conv_states.shape[2] >= width - 1)
+            assert stride_istate_dim == 1
+        if cache_indices is not None:
+            assert cache_indices.dim() == 1
+            assert padded_batch == cache_indices.size(0)
+        if has_initial_state is not None:
+            assert has_initial_state.size() == (padded_batch, )
+            assert conv_states is not None, (
+                "`has_initial_state` requires `conv_states` to be provided")
+        assert weight.stride(1) == 1
+        assert (dim, width) == weight.shape
+        assert is_channel_last, "Need to run in channel-last layout"
+
+    BLOCK_M = 64
+    seqlens = query_start_loc.diff()
+    # ceil division: number of BLOCK_M-sized chunks per sequence
+    seq_blocks = -(-seqlens // BLOCK_M)
+    total_seq_blocks = seq_blocks.sum().item()
+    # tracking which seq-idx the Triton program is handling
+    batch_ptr = torch.repeat_interleave(
+        torch.arange(len(seq_blocks), device=x.device),
+        seq_blocks).to(torch.int32)
+
+    # tracking the BLOCK_M-based index in the sequence the Triton program is handling
+    max_blocks = seq_blocks.max().item() if len(seq_blocks) > 0 else 0
+    arange = torch.arange(max_blocks, device=x.device)
+    mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1)
+    token_chunk_offset_ptr = arange.repeat(len(seq_blocks),
+                                           1)[mask].to(torch.int32)
+
+    BLOCK_N = 256
+    grid = (total_seq_blocks, triton.cdiv(dim, BLOCK_N))
+
+    with torch.npu.device(x.device.index):
+        _causal_conv1d_fwd_kernel[grid](
+            # Pointers to matrices
+            x,
+            weight,
+            bias,
+            conv_states,
+            cache_indices,
+            has_initial_state,
+            query_start_loc,
+            batch_ptr,
+            token_chunk_offset_ptr,
+            out,
+            # Matrix dimensions
+            dim,
+            state_len,
+            num_cache_lines,
+            # stride
+            stride_x_dim,
+            stride_x_token,
+            stride_w_dim,
+            stride_w_width,
+            stride_istate_seq,
+            stride_istate_dim,
+            stride_istate_token,
+            stride_cache_indices,
+            stride_o_dim,
+            stride_o_token,
+            # others
+            pad_slot_id,
+            # META
+            HAS_BIAS=bias is not None,
+            KERNEL_WIDTH=width,
+            SILU_ACTIVATION=activation in ["silu", "swish"],
+            HAS_INITIAL_STATES=has_initial_state is not None,
+            IS_CONTINUOUS_BATCHING=cache_indices is not None,
+            USE_PAD_SLOT=pad_slot_id is not None,
+            NP2_STATELEN=np2_statelen,
+            BLOCK_M=BLOCK_M,
+            BLOCK_N=BLOCK_N)
+
+    return out


 # TODO copied from vllm and it needs to be optimized
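For reference, the launch-side scheduling arithmetic can be checked in
isolation. A minimal sketch (plain PyTorch on CPU, illustrative sequence
lengths) reproducing how `batch_ptr` and `token_chunk_offset_ptr` enumerate
one Triton program per (sequence, BLOCK_M-chunk) pair:

import torch

BLOCK_M = 64
query_start_loc = torch.tensor([0, 100, 130], dtype=torch.int32)  # two seqs

seqlens = query_start_loc.diff()            # [100, 30]
seq_blocks = -(-seqlens // BLOCK_M)         # ceil division -> [2, 1]
total_seq_blocks = seq_blocks.sum().item()  # 3 programs along grid axis 0

# program i handles sequence batch_ptr[i], chunk token_chunk_offset_ptr[i]
batch_ptr = torch.repeat_interleave(torch.arange(len(seq_blocks)),
                                    seq_blocks)                   # [0, 0, 1]
arange = torch.arange(seq_blocks.max())
mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1)
token_chunk_offset_ptr = arange.repeat(len(seq_blocks), 1)[mask]  # [0, 1, 0]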
diff --git a/vllm_ascend/patch/worker/patch_triton.py b/vllm_ascend/patch/worker/patch_triton.py
index 92e9a8a9..af0909c1 100644
--- a/vllm_ascend/patch/worker/patch_triton.py
+++ b/vllm_ascend/patch/worker/patch_triton.py
@@ -4,7 +4,7 @@ from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
 from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn
 from vllm_ascend.ops.triton.fla.sigmoid_gating import \
     fused_recurrent_gated_delta_rule_fwd_kernel
-from vllm_ascend.ops.triton.mamba.casual_conv1d import (
+from vllm_ascend.ops.triton.mamba.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update_npu)

 vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
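A final note on the cache indirection shared by the kernel and the test:
`cache_indices` decouples a sequence's position in the batch from its row in
`conv_states`, and `pad_slot_id` marks padded entries to skip. A minimal
sketch (hypothetical sizes, not taken from this PR):

import torch

PAD_SLOT_ID = -1
num_cache_lines, dim, state_len = 32, 64, 3
conv_states = torch.zeros(num_cache_lines, dim, state_len)

# Four sequences in the batch; entries 0 and 3 are padding and are skipped.
cache_indices = torch.tensor([PAD_SLOT_ID, 1, 20, PAD_SLOT_ID])
for i, slot in enumerate(cache_indices.tolist()):
    if slot == PAD_SLOT_ID:
        continue  # the kernel returns early for these programs
    state_i = conv_states[slot]  # (dim, state_len) cache line for sequence i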