[OPS] support triton causal_conv1d_fn ops (#4119)

### What this PR does / why we need it?
Add a Triton implementation of the `causal_conv1d_fn` op for Ascend NPU, replacing the pure-PyTorch fallback in `vllm_ascend.ops.triton.mamba.causal_conv1d`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with the newly added test and existing tests.
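
For reference, invoking the new op looks like the following minimal sketch, which mirrors the added test (it assumes an Ascend NPU device with `torch_npu` available; the op expects channel-last varlen input):

```python
import torch

from vllm_ascend.ops.triton.mamba.causal_conv1d import causal_conv1d_fn

dim, width, seq_lens = 64, 4, [3, 5]  # two sequences, concatenated along tokens
cu_seqlen, num_seq = sum(seq_lens), len(seq_lens)

# Channel-last layout: the kernel requires x.stride(0) == 1.
x = torch.randn(cu_seqlen, dim, device="npu").transpose(0, 1)  # (dim, cu_seqlen)
weight = torch.randn(dim, width, device="npu")
conv_states = torch.zeros((num_seq, width - 1, dim),
                          device="npu").transpose(-1, -2)  # (num_seq, dim, width - 1)
query_start_loc = torch.cumsum(
    torch.tensor([0] + seq_lens, device="npu", dtype=torch.int32), dim=0)
cache_indices = torch.arange(num_seq, device="npu", dtype=torch.int32)
has_initial_state = torch.zeros(num_seq, device="npu", dtype=torch.bool)

out = causal_conv1d_fn(x,
                       weight,
                       bias=None,
                       activation="silu",
                       conv_states=conv_states,  # updated in place
                       has_initial_state=has_initial_state,
                       cache_indices=cache_indices,
                       query_start_loc=query_start_loc)
# out: same (dim, cu_seqlen) shape and strides as x
```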

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: QilaiZhang <245706640@qq.com>
QilaiZhang authored 2025-12-11 15:52:39 +08:00, committed by GitHub
parent eac72f5f23
commit 78bf211539
3 changed files with 632 additions and 93 deletions


@@ -0,0 +1,230 @@
from typing import Optional

import pytest
import torch
import torch.nn.functional as F

from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
                                                         causal_conv1d_fn)


def validate_cmp(y_cal, y_ref, dtype, device='npu'):
    y_cal = y_cal.to(device)
    y_ref = y_ref.to(device)
    if dtype == torch.float16:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=3e-03,
                                   atol=1e-02,
                                   equal_nan=True)
    elif dtype == torch.bfloat16:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-02,
                                   atol=1e-02,
                                   equal_nan=True)
    elif dtype == torch.float32:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-03,
                                   atol=4e-03,
                                   equal_nan=True)
    elif dtype in (torch.int32, torch.int64, torch.int16, torch.int8,
                   torch.uint32):
        assert torch.equal(y_cal, y_ref)
    elif dtype == torch.bool:
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError(
            'Invalid parameter "dtype" is found: {}'.format(dtype))


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        out = F.conv1d(x,
                       weight.unsqueeze(1),
                       bias,
                       padding=width - 1,
                       groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in)  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


def causal_conv1d_fn_pytorch(
    x: torch.Tensor,
    weight: torch.Tensor,
    query_start_loc: torch.Tensor,
    cache_indices: torch.Tensor,
    has_initial_state: torch.Tensor,
    conv_states: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. prepended by 0.
        for example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the
        initial state for the calculations
    conv_states: (...,dim,width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
    out_ref = []
    out_ref_b = []
    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)
    width = weight.shape[1]
    for i in range(len(seqlens)):
        x_s = splits[i]
        if cache_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]][..., :(
                    width - 1)].unsqueeze(0),
                initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
                if has_initial_state[i] else None))
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    return out_ref_tensor


@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype',
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [False, True])
@pytest.mark.parametrize('has_bias', [False, True])
@pytest.mark.parametrize('seq_len', [[
    1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134,
    2048, 4096
]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [2, 3, 4])
@pytest.mark.parametrize('dim', [64, 4160])
def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                       silu_activation, itype, has_initial_state):
    torch.random.manual_seed(0)
    device = "npu"
    cu_seqlen, num_seq = sum(seq_len), len(seq_len)
    state_len = width - 1 + extra_state_len
    x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
    weight = torch.randn(dim, width, device=device, dtype=itype)
    query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
                                                device=device,
                                                dtype=torch.int32),
                                   dim=0)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = None if not silu_activation else "silu"
    if has_initial_state:
        conv_states = torch.randn((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.randn(
            (num_seq, state_len, dim), device=device,
            dtype=itype).transpose(-1, -2).copy_(conv_states)
    else:
        conv_states = torch.zeros((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.zeros((num_seq, state_len, dim),
                                      device=device,
                                      dtype=itype).transpose(-1, -2)
    if has_bias:
        bias = torch.randn(dim, device=device, dtype=itype)
    else:
        bias = None
    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc)
    out = causal_conv1d_fn(x,
                           weight,
                           bias=bias,
                           activation=activation,
                           conv_states=conv_states,
                           has_initial_state=has_initial_state_tensor,
                           cache_indices=cache_indices,
                           query_start_loc=query_start_loc)
    validate_cmp(out, out_ref, itype)
    validate_cmp(conv_states, conv_states_ref, itype)
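
The transposes in the test above are load-bearing: the Triton kernel asserts a channel-last layout (`x.stride(0) == 1`), so tensors are allocated token-major and then transposed, rather than being created as `(dim, tokens)` directly. A quick illustration of the stride difference (plain CPU tensors, for exposition only):

```python
import torch

cu_seqlen, dim = 8, 4

# What the test does: allocate (tokens, dim), then view it as (dim, tokens).
x = torch.randn(cu_seqlen, dim).transpose(0, 1)
print(x.shape, x.stride())  # torch.Size([4, 8]) (1, 4) -> channel-last, stride(0) == 1

# A directly allocated (dim, tokens) tensor is row-contiguous instead,
# and would fail the kernel's channel-last check.
x_bad = torch.randn(dim, cu_seqlen)
print(x_bad.stride())  # (8, 1)
```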


@@ -1,4 +1,4 @@
-# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py
+# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
 # SPDX-License-Identifier: Apache-2.0
@@ -10,78 +10,289 @@
 from typing import Any, Optional, Union
 import torch
-import torch.nn.functional as F
 import triton
 import triton.language as tl
 PAD_SLOT_ID = -1
-def causal_conv1d_ref(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    initial_states: Optional[torch.Tensor] = None,
-    return_final_states: bool = False,
-    final_states_out: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
-):
-    """
-    x: (batch, dim, seqlen)
-    weight: (dim, width)
-    bias: (dim,)
-    initial_states: (batch, dim, width - 1)
-    final_states_out: (batch, dim, width - 1)
-    out: (batch, dim, seqlen)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    dtype_in = x.dtype
-    x = x.to(weight.dtype)
-    seqlen = x.shape[-1]
-    dim, width = weight.shape
-    if initial_states is None:
-        out = F.conv1d(x,
-                       weight.unsqueeze(1),
-                       bias,
-                       padding=width - 1,
-                       groups=dim)
-    else:
-        x = torch.cat([initial_states, x], dim=-1)
-        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
-    out = out[..., :seqlen]
-    if return_final_states:
-        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
-            dtype_in)  # (batch, dim, width - 1)
-        if final_states_out is not None:
-            final_states_out[..., :(width - 1)].copy_(final_states)
-        else:
-            final_states_out = final_states
-    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
-    return (out, None) if not return_final_states else (out, final_states_out)
+@triton.jit()
+def _causal_conv1d_fwd_kernel(  # continuous batching
+    # Pointers to matrices
+    x_ptr,  # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
+    w_ptr,  # (dim, width)
+    bias_ptr,
+    conv_states_ptr,
+    conv_state_indices_ptr,
+    has_initial_states_ptr,
+    query_start_loc_ptr,
+    batch_ptr,
+    token_chunk_offset_ptr,
+    o_ptr,  # (dim, seqlen)
+    # Matrix dimensions
+    dim: tl.constexpr,
+    state_len: int,
+    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
+    # Strides
+    stride_x_dim: tl.constexpr,  # stride to get to next feature-value,
+    stride_x_token: tl.constexpr,  # stride to get to next token
+    stride_w_dim: tl.constexpr,  # stride to get to next dim-axis value
+    stride_w_width: tl.constexpr,  # stride to get to next width-axis value
+    stride_conv_state_seq: tl.constexpr,
+    stride_conv_state_dim: tl.constexpr,
+    stride_conv_state_tok: tl.constexpr,
+    stride_cache_indices: tl.constexpr,
+    stride_o_dim: tl.constexpr,
+    stride_o_token: tl.constexpr,
+    # others
+    pad_slot_id: tl.constexpr,
+    # Meta-parameters
+    HAS_BIAS: tl.constexpr,
+    KERNEL_WIDTH: tl.constexpr,
+    SILU_ACTIVATION: tl.constexpr,
+    HAS_INITIAL_STATES: tl.constexpr,
+    IS_CONTINUOUS_BATCHING: tl.constexpr,
+    USE_PAD_SLOT: tl.constexpr,
+    NP2_STATELEN: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    # single-sequence id
+    idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
+    chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
+    # BLOCK_N elements along the feature-dimension (channel)
+    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+    sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
+    sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
+    # find the actual sequence length
+    seqlen = sequence_end_index - sequence_start_index
+    token_offset = BLOCK_M * chunk_offset
+    segment_len = min(BLOCK_M, seqlen - token_offset)
+    # base of the sequence
+    x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim  # [BLOCK_N,]
+    if IS_CONTINUOUS_BATCHING:
+        # cache_idx
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
+                                         idx_seq * stride_cache_indices).to(
+                                             tl.int64)
+    else:
+        # cache_idx
+        conv_state_batch_coord = idx_seq
+    if USE_PAD_SLOT:  # noqa
+        if conv_state_batch_coord == pad_slot_id:
+            # not processing as this is not the actual sequence
+            return
+    conv_states_base = conv_states_ptr + (
+        conv_state_batch_coord * stride_conv_state_seq) + (
+            idx_feats * stride_conv_state_dim)  # [BLOCK_N,]
+    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
+    load_init_state = False
+    if HAS_INITIAL_STATES:  # the new HAS_INITIAL_STATES
+        load_init_state = tl.load(has_initial_states_ptr + idx_seq)
+    mask_dim = idx_feats < dim
+    # read prior-token data from `x`
+    offset_x = token_offset - KERNEL_WIDTH + 1
+    if KERNEL_WIDTH >= 2:
+        x0_ptrs = x_base + offset_x * stride_x_token
+        x0 = tl.load(x0_ptrs, mask_dim & (offset_x > 0))
+    if KERNEL_WIDTH >= 3:
+        x1_ptrs = x0_ptrs + 1 * stride_x_token
+        x1 = tl.load(x1_ptrs, mask_dim & (offset_x + 1 > 0))
+    if KERNEL_WIDTH >= 4:
+        x2_ptrs = x1_ptrs + 1 * stride_x_token
+        x2 = tl.load(x2_ptrs, mask_dim & (offset_x + 2 > 0))
+    if load_init_state & (chunk_offset == 0):
+        # load from conv_states
+        offset_conv_state = state_len - KERNEL_WIDTH + 1
+        if KERNEL_WIDTH >= 2:
+            x0_ptrs = conv_states_base + offset_conv_state * stride_conv_state_tok
+            x0 = tl.load(x0_ptrs, mask_dim, 0.0)
+        if KERNEL_WIDTH >= 3:
+            x1_ptrs = x0_ptrs + 1 * stride_conv_state_tok
+            x1 = tl.load(x1_ptrs, mask_dim)
+        if KERNEL_WIDTH >= 4:
+            x2_ptrs = x1_ptrs + 1 * stride_conv_state_tok
+            x2 = tl.load(x2_ptrs, mask_dim)
+    if HAS_BIAS:
+        bias = bias_ptr + idx_feats
+        mask_bias = idx_feats < dim
+        acc_preload = tl.load(bias, mask=mask_bias,
+                              other=0.0).to(tl.float32)  # [BLOCK_N]
+    else:
+        acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
+    x_base_1d = x_base + token_offset * stride_x_token  # starting of chunk
+    # PRE-LOAD WEIGHTS
+    mask_dim = idx_feats < dim
+    if KERNEL_WIDTH >= 2:
+        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
+        w0 = tl.load(w_ptrs, mask_dim, other=0.0)
+        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
+        w1 = tl.load(w_ptrs, mask_dim, other=0.0)
+    if KERNEL_WIDTH >= 3:
+        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
+        w2 = tl.load(w_ptrs, mask_dim, other=0.0)
+    if KERNEL_WIDTH >= 4:
+        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
+        w3 = tl.load(w_ptrs, mask_dim, other=0.0)
+    for idx_token in tl.static_range(BLOCK_M):
+        acc = acc_preload
+        mask_1d = (idx_token
+                   < segment_len) & mask_dim  # token-index # feature-index
+        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+        x = tl.load(x_ptrs_1d, mask=mask_1d)
+        if KERNEL_WIDTH == 2:
+            acc += x0 * w0 + x * w1
+            x0 = x
+        elif KERNEL_WIDTH == 3:
+            acc += x0 * w0 + x1 * w1 + x * w2
+            x0 = x1
+            x1 = x
+        elif KERNEL_WIDTH == 4:
+            acc += x0 * w0 + x1 * w1 + x2 * w2 + x * w3
+            x0 = x1
+            x1 = x2
+            x2 = x
+        if SILU_ACTIVATION:
+            acc = acc / (1 + tl.exp(-acc))
+        o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
+                          ) * stride_o_token + (idx_feats * stride_o_dim)
+        tl.store(o_ptrs, acc, mask=mask_1d)
+    # update conv_state with new data [only by the Triton program handles chunk_offset=0]
+    if chunk_offset == 0:
+        if state_len <= seqlen:  # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
+            # just read from 'x'
+            # copy 'x' data to conv_state
+            # load only 'x' data (and set 0 before 'x' if seqlen < state_len)
+            idx_tokens_last = (seqlen - state_len) + tl.arange(
+                0, NP2_STATELEN)  # [BLOCK_M]
+            x_ptrs = x_ptr + (
+                (sequence_start_index + idx_tokens_last) *
+                stride_x_token)[:, None] + (
+                    idx_feats * stride_x_dim)[None, :]  # [BLOCK_M,BLOCK_N,]
+            mask_x = ((idx_tokens_last >= 0)[:, None] &
+                      (idx_tokens_last < seqlen)[:, None] &
+                      (idx_feats < dim)[None, :]
+                      )  # token-index # token-index # feature-index
+            new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+            conv_states_ptrs_target = conv_states_base[None, :] + (
+                idx_tokens_conv * stride_conv_state_tok)[:, None]
+            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
+                                                             < dim)[None, :]
+            tl.debug_barrier()
+            tl.store(conv_states_ptrs_target, new_conv_state, mask)
+        elif load_init_state:
+            # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
+            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+            conv_states_ptrs_source = (
+                conv_states_ptr +
+                (conv_state_batch_coord * stride_conv_state_seq) +
+                (idx_feats * stride_conv_state_dim)[None, :] +
+                ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
+            )  # [BLOCK_M, BLOCK_N]
+            mask = ((conv_state_batch_coord < num_cache_lines)
+                    & ((idx_tokens_conv + seqlen) < state_len)[:, None]
+                    & (idx_feats < dim)[None, :])
+            conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
+            VAL = state_len - seqlen
+            x_ptrs = x_base[None, :] + (
+                (idx_tokens_conv - VAL) *
+                stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
+            mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
+                      (idx_tokens_conv - VAL < seqlen)[:, None] &
+                      (idx_feats < dim)[None, :]
+                      )  # token-index # token-index # feature-index
+            loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+            tl.debug_barrier()
+            new_conv_state = tl.where(
+                mask, conv_state, loaded_x
+            )  # BUG in 'tl.where' which requires a barrier before this
+            conv_states_ptrs_target = conv_states_base + (
+                idx_tokens_conv *
+                stride_conv_state_tok)[:, None]  # [BLOCK_M, BLOCK_N]
+            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
+                                                             < dim)[None, :]
+            tl.store(conv_states_ptrs_target, new_conv_state, mask)
+        else:
+            # update conv_state by shifting left, BUT
+            # set cols prior to 'x' as zeros + cols from 'x'
+            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+            VAL = state_len - seqlen
+            x_ptrs = x_base[None, :] + (
+                (idx_tokens_conv - VAL) *
+                stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
+            mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
+                      (idx_tokens_conv - VAL < seqlen)[:, None] &
+                      (idx_feats < dim)[None, :]
+                      )  # token-index # token-index # feature-index
+            new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+            conv_states_ptrs_target = conv_states_base + (
+                idx_tokens_conv *
+                stride_conv_state_tok)[:, None]  # [BLOCK_M, BLOCK_N]
+            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
+                                                             < dim)[None, :]
+            tl.debug_barrier()
+            tl.store(conv_states_ptrs_target, new_conv_state, mask)
-def causal_conv1d_fn(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    query_start_loc: Optional[torch.Tensor] = None,
-    cache_indices: Optional[torch.Tensor] = None,
-    has_initial_state: Optional[torch.Tensor] = None,
-    conv_states: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
-    pad_slot_id: int = PAD_SLOT_ID,
-    metadata: Optional[Any] = None,
-):
+def causal_conv1d_fn(x: torch.Tensor,
+                     weight: torch.Tensor,
+                     bias: Union[torch.Tensor, None],
+                     conv_states: torch.Tensor,
+                     query_start_loc: torch.Tensor,
+                     cache_indices: Optional[torch.Tensor] = None,
+                     has_initial_state: Optional[torch.Tensor] = None,
+                     activation: Optional[str] = "silu",
+                     pad_slot_id: int = PAD_SLOT_ID,
+                     metadata: Optional[Any] = None,
+                     validate_data=False):
-    """
-    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
+    """support varlen + continuous batching when x is 2D tensor
+    x: (dim,cu_seq_len)
+        cu_seq_len = total tokens of all seqs in that batch
         sequences are concatenated from left to right for varlen
     weight: (dim, width)
-    bias: (dim,)
+    conv_states: (...,dim,width - 1) itype
+        updated inplace if provided
+        [it use `cache_indices` to get the index to the cache of conv_state for that sequence
+         conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
+         and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x'
+        ]
     query_start_loc: (batch + 1) int32
         The cumulative sequence lengths of the sequences in
         the batch, used to index into sequence. prepended by 0.
+        if
+           x = [5, 1, 1, 1] <- continuous batching (batch=4)
+        then
+           query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is
+             the ending index of the last sequence
+           [length(query_start_loc)-1 == batch]
         for example: query_start_loc = torch.Tensor([0,10,16,17]),
         x.shape=(dim,17)
     cache_indices: (batch) int32
@@ -90,46 +301,144 @@ def causal_conv1d_fn(
     has_initial_state: (batch) bool
         indicates whether should the kernel take the current state as initial
         state for the calculations
-    conv_states: (...,dim,width - 1) itype
-        updated inplace if provided
-    activation: either None or "silu" or "swish"
+        [single boolean for each sequence in the batch: True or False]
+    bias: (dim,)
+    activation: either None or "silu" or "swish" or True
     pad_slot_id: int
         if cache_indices is passed, lets the kernel identify padded
         entries that will not be processed,
         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
         in this case, the kernel will not process entries at
         indices 0 and 3
-    out: (batch, dim, seqlen)
+    out: same shape as `x`
     """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    if x.stride(-1) != 1:
-        x = x.contiguous()
-    bias = bias.contiguous() if bias is not None else None
-    out_ref = []
-    out_ref_b = []
-    seqlens = query_start_loc[1:] - query_start_loc[:-1]
-    seqlens = seqlens.tolist()
-    splits = torch.split(x, seqlens, dim=-1)
-    for i in range(len(seqlens)):
-        x_s = splits[i]
-        if cache_indices[i] == PAD_SLOT_ID:
-            continue
-        out_ref_b.append(
-            causal_conv1d_ref(
-                x_s,
-                weight,
-                bias,
-                activation=activation,
-                return_final_states=True,
-                final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
-                initial_states=conv_states[cache_indices[i]]
-                if has_initial_state[i] else None))
-    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
-    out_ref_tensor = torch.cat(out_ref, dim=0)
-    return out_ref_tensor
+    if isinstance(activation, bool) and activation:
+        activation = "silu"
+    # Store original dtype to cast back at the end
+    out = torch.empty_strided(x.size(),
+                              x.stride(),
+                              dtype=x.dtype,
+                              device=x.device)
+    dim, _ = x.shape
+    _, width = weight.shape
+    state_len = width - 1
+    np2_statelen = triton.next_power_of_2(state_len)
+    padded_batch = query_start_loc.size(0) - 1
+    stride_x_dim = x.stride(0)
+    stride_x_token = x.stride(1)
+    stride_w_dim = weight.stride(0)
+    stride_w_width = weight.stride(1)
+    stride_istate_seq = 0
+    stride_istate_dim = 0
+    stride_istate_token = 0
+    stride_o_dim = out.stride(0)
+    stride_o_token = out.stride(1)
+    num_cache_lines = 0
+    if conv_states is not None:
+        # extensions to support vLLM:
+        # 1. conv_states is used to replaced initial_states
+        # 2. conv_states serve as a cache with num cache lines can be larger than batch size
+        # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
+        # 4. computation can be skipped if cache_indices[idx] == pad_slot_id
+        num_cache_lines = conv_states.size(0)
+        stride_istate_seq = conv_states.stride(0)
+        stride_istate_dim = conv_states.stride(1)
+        stride_istate_token = conv_states.stride(2)
+    stride_cache_indices = cache_indices.stride(
+        0) if cache_indices is not None else 0
+    if validate_data:
+        is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
+        assert x.dim() == 2
+        assert width in [2, 3, 4]
+        assert query_start_loc is not None
+        assert query_start_loc.dim() == 1
+        assert x.stride(0) == 1 or x.stride(1) == 1
+        if bias is not None:
+            assert bias.dim() == 1
+            assert dim == bias.size(0)
+        if conv_states is not None:
+            assert (num_cache_lines == conv_states.shape[0]
+                    and dim == conv_states.shape[1]
+                    and conv_states.shape[2] >= width - 1)
+            assert stride_istate_dim == 1
+        if cache_indices is not None:
+            assert cache_indices.dim() == 1
+            assert padded_batch == cache_indices.size(0)
+        if has_initial_state is not None:
+            assert has_initial_state.size() == (padded_batch, )
+            assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
+        assert weight.stride(1) == 1
+        assert (dim, width) == weight.shape
+        assert is_channel_last, "Need to run in channel-last layout"
+    BLOCK_M = 64
+    seqlens = query_start_loc.diff()
+    seq_blocks = -(-seqlens // BLOCK_M)
+    total_seq_blocks = seq_blocks.sum().item()
+    # tracking which seq-idx the Triton program is handling
+    batch_ptr = torch.repeat_interleave(
+        torch.arange(len(seq_blocks), device=x.device),
+        seq_blocks).to(torch.int32)
+    # tracking BLOCK_M-based index in the sequence the Triton program is handling
+    max_blocks = seq_blocks.max().item() if len(seq_blocks) > 0 else 0
+    arange = torch.arange(max_blocks, device=x.device)
+    mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1)
+    token_chunk_offset_ptr = arange.repeat(len(seq_blocks),
+                                           1)[mask].to(torch.int32)
+    BLOCK_N = 256
+    grid = (total_seq_blocks, triton.cdiv(dim, BLOCK_N))
+    with torch.npu.device(x.device.index):
+        _causal_conv1d_fwd_kernel[grid](
+            # Pointers to matrices
+            x,
+            weight,
+            bias,
+            conv_states,
+            cache_indices,
+            has_initial_state,
+            query_start_loc,
+            batch_ptr,
+            token_chunk_offset_ptr,
+            out,
+            # Matrix dimensions
+            dim,
+            state_len,
+            num_cache_lines,
+            # stride
+            stride_x_dim,
+            stride_x_token,
+            stride_w_dim,
+            stride_w_width,
+            stride_istate_seq,
+            stride_istate_dim,
+            stride_istate_token,
+            stride_cache_indices,
+            stride_o_dim,
+            stride_o_token,
+            # others
+            pad_slot_id,
+            # META
+            HAS_BIAS=bias is not None,
+            KERNEL_WIDTH=width,
+            SILU_ACTIVATION=activation in ["silu", "swish"],
+            HAS_INITIAL_STATES=has_initial_state is not None,
+            IS_CONTINUOUS_BATCHING=cache_indices is not None,
+            USE_PAD_SLOT=pad_slot_id is not None,
+            NP2_STATELEN=np2_statelen,
+            BLOCK_M=BLOCK_M,
+            BLOCK_N=BLOCK_N)
+    return out
 # TODO copied from vllm and it needs to be optimized
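
The `batch_ptr` / `token_chunk_offset_ptr` bookkeeping above flattens the (sequence, chunk) iteration space so that each Triton program on grid axis 0 knows which sequence and which `BLOCK_M`-sized chunk it owns. A worked sketch of that mapping on CPU, with `BLOCK_M` shrunk from 64 to 4 purely so the output stays readable:

```python
import torch

BLOCK_M = 4  # illustrative only; the kernel launch uses BLOCK_M = 64
query_start_loc = torch.tensor([0, 5, 6, 13])  # three sequences: lengths 5, 1, 7
seqlens = query_start_loc.diff()               # tensor([5, 1, 7])
seq_blocks = -(-seqlens // BLOCK_M)            # ceil-div -> tensor([2, 1, 2])

# One entry per launched program along grid axis 0.
batch_ptr = torch.repeat_interleave(torch.arange(len(seq_blocks)), seq_blocks)
arange = torch.arange(seq_blocks.max())
mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1)
token_chunk_offset_ptr = arange.repeat(len(seq_blocks), 1)[mask]

print(batch_ptr)               # tensor([0, 0, 1, 2, 2]) -> which sequence
print(token_chunk_offset_ptr)  # tensor([0, 1, 0, 0, 1]) -> which chunk within it
```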


@@ -4,7 +4,7 @@ from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
 from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn
 from vllm_ascend.ops.triton.fla.sigmoid_gating import \
     fused_recurrent_gated_delta_rule_fwd_kernel
-from vllm_ascend.ops.triton.mamba.casual_conv1d import (
+from vllm_ascend.ops.triton.mamba.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update_npu)

 vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
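
Because the override is installed by monkey-patching module attributes, it is observable from vLLM's side once this patch module has been imported. A hedged sanity check (the patch module's own import path is not shown in this diff, so it is assumed to be pulled in by the vllm-ascend plugin load):

```python
import vllm.model_executor.layers.mamba.ops.causal_conv1d as vllm_causal_conv1d

from vllm_ascend.ops.triton.mamba.causal_conv1d import causal_conv1d_update_npu

# Assumes the patch module shown above has already been imported (e.g. via the
# vllm-ascend platform plugin); afterwards vLLM dispatches to the NPU kernel.
assert vllm_causal_conv1d.causal_conv1d_update is causal_conv1d_update_npu
```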