[OPS] support triton causal_conv1d_fn ops (#4119)
### What this PR does / why we need it?
Support triton causal_conv1d_fn ops.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: QilaiZhang <245706640@qq.com>
This commit is contained in:
230
tests/e2e/nightly/ops/triton/test_causal_conv1d.py
Normal file
230
tests/e2e/nightly/ops/triton/test_causal_conv1d.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
|
||||
causal_conv1d_fn)
|
||||
|
||||
|
||||
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
|
||||
y_cal = y_cal.to(device)
|
||||
y_ref = y_ref.to(device)
|
||||
if dtype == torch.float16:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=3e-03,
|
||||
atol=1e-02,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.bfloat16:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=1e-02,
|
||||
atol=1e-02,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.float32:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=1e-03,
|
||||
atol=4e-03,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.int32 or dtype == torch.int64 or dtype == torch.int16 or dtype == torch.int8 or dtype == torch.uint32:
|
||||
assert torch.equal(y_cal, y_ref)
|
||||
elif dtype == torch.bool:
|
||||
assert torch.equal(y_cal, y_ref)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Invalid parameter \"dtype\" is found : {}'.format(dtype))
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_fn_pytorch(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
cache_indices: torch.Tensor,
|
||||
has_initial_state: torch.Tensor,
|
||||
conv_states: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
||||
sequences are concatenated from left to right for varlen
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended by 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
indicates the corresponding state index,
|
||||
like so: conv_state = conv_states[cache_indices[batch_id]]
|
||||
has_initial_state: (batch) bool
|
||||
indicates whether should the kernel take the current state as initial
|
||||
state for the calculations
|
||||
conv_states: (...,dim,width - 1) itype
|
||||
updated inplace if provided
|
||||
activation: either None or "silu" or "swish"
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
seqlens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
seqlens = seqlens.tolist()
|
||||
splits = torch.split(x, seqlens, dim=-1)
|
||||
width = weight.shape[1]
|
||||
|
||||
for i in range(len(seqlens)):
|
||||
x_s = splits[i]
|
||||
if cache_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=conv_states[cache_indices[i]][..., :(
|
||||
width - 1)].unsqueeze(0),
|
||||
initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
|
||||
if has_initial_state[i] else None))
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
return out_ref_tensor
|
||||
|
||||
|
||||
@pytest.mark.parametrize('has_initial_state', [False, True])
|
||||
@pytest.mark.parametrize('itype',
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize('silu_activation', [False, True])
|
||||
@pytest.mark.parametrize('has_bias', [False, True])
|
||||
@pytest.mark.parametrize('seq_len', [[
|
||||
1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134,
|
||||
2048, 4096
|
||||
]])
|
||||
@pytest.mark.parametrize('extra_state_len', [0, 2])
|
||||
@pytest.mark.parametrize('width', [2, 3, 4])
|
||||
@pytest.mark.parametrize('dim', [64, 4160])
|
||||
def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
|
||||
silu_activation, itype, has_initial_state):
|
||||
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
device = "npu"
|
||||
cu_seqlen, num_seq = sum(seq_len), len(seq_len)
|
||||
state_len = width - 1 + extra_state_len
|
||||
|
||||
x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
|
||||
device=device,
|
||||
dtype=torch.int32),
|
||||
dim=0)
|
||||
cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
|
||||
has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
|
||||
device=device,
|
||||
dtype=torch.bool)
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
if has_initial_state:
|
||||
conv_states = torch.randn((num_seq, state_len, dim),
|
||||
device=device,
|
||||
dtype=itype).transpose(-1, -2)
|
||||
conv_states_ref = torch.randn(
|
||||
(num_seq, state_len, dim), device=device,
|
||||
dtype=itype).transpose(-1, -2).copy_(conv_states)
|
||||
else:
|
||||
conv_states = torch.zeros((num_seq, state_len, dim),
|
||||
device=device,
|
||||
dtype=itype).transpose(-1, -2)
|
||||
conv_states_ref = torch.zeros((num_seq, state_len, dim),
|
||||
device=device,
|
||||
dtype=itype).transpose(-1, -2)
|
||||
|
||||
if has_bias:
|
||||
bias = torch.randn(dim, device=device, dtype=itype)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
out_ref = causal_conv1d_fn_pytorch(
|
||||
x,
|
||||
weight,
|
||||
bias=bias,
|
||||
activation=activation,
|
||||
conv_states=conv_states_ref,
|
||||
has_initial_state=has_initial_state_tensor,
|
||||
cache_indices=cache_indices,
|
||||
query_start_loc=query_start_loc)
|
||||
out = causal_conv1d_fn(x,
|
||||
weight,
|
||||
bias=bias,
|
||||
activation=activation,
|
||||
conv_states=conv_states,
|
||||
has_initial_state=has_initial_state_tensor,
|
||||
cache_indices=cache_indices,
|
||||
query_start_loc=query_start_loc)
|
||||
|
||||
validate_cmp(out, out_ref, itype)
|
||||
validate_cmp(conv_states, conv_states_ref, itype)
|
||||
@@ -1,4 +1,4 @@
|
||||
# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py
|
||||
# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
@@ -10,78 +10,289 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
@triton.jit()
|
||||
def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# Pointers to matrices
|
||||
x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
conv_states_ptr,
|
||||
conv_state_indices_ptr,
|
||||
has_initial_states_ptr,
|
||||
query_start_loc_ptr,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
o_ptr, # (dim, seqlen)
|
||||
# Matrix dimensions
|
||||
dim: tl.constexpr,
|
||||
state_len: int,
|
||||
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||
# Strides
|
||||
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
|
||||
stride_x_token: tl.constexpr, # stride to get to next token
|
||||
stride_w_dim: tl.constexpr, # stride to get to next dim-axis value
|
||||
stride_w_width: tl.constexpr, # stride to get to next width-axis value
|
||||
stride_conv_state_seq: tl.constexpr,
|
||||
stride_conv_state_dim: tl.constexpr,
|
||||
stride_conv_state_tok: tl.constexpr,
|
||||
stride_cache_indices: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
HAS_INITIAL_STATES: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
# single-sequence id
|
||||
idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
|
||||
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
|
||||
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
# BLOCK_N elements along the feature-dimension (channel)
|
||||
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
|
||||
sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
|
||||
# find the actual sequence length
|
||||
seqlen = sequence_end_index - sequence_start_index
|
||||
|
||||
token_offset = BLOCK_M * chunk_offset
|
||||
segment_len = min(BLOCK_M, seqlen - token_offset)
|
||||
|
||||
# base of the sequence
|
||||
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out[..., :(width - 1)].copy_(final_states)
|
||||
# cache_idx
|
||||
conv_state_batch_coord = idx_seq
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_state_batch_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
conv_states_base = conv_states_ptr + (
|
||||
conv_state_batch_coord * stride_conv_state_seq) + (
|
||||
idx_feats * stride_conv_state_dim) # [BLOCK_N,]
|
||||
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||
|
||||
load_init_state = False
|
||||
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
|
||||
load_init_state = tl.load(has_initial_states_ptr + idx_seq)
|
||||
|
||||
mask_dim = idx_feats < dim
|
||||
|
||||
# read prior-token data from `x`
|
||||
offset_x = token_offset - KERNEL_WIDTH + 1
|
||||
if KERNEL_WIDTH >= 2:
|
||||
x0_ptrs = x_base + offset_x * stride_x_token
|
||||
x0 = tl.load(x0_ptrs, mask_dim & (offset_x > 0))
|
||||
if KERNEL_WIDTH >= 3:
|
||||
x1_ptrs = x0_ptrs + 1 * stride_x_token
|
||||
x1 = tl.load(x1_ptrs, mask_dim & (offset_x + 1 > 0))
|
||||
if KERNEL_WIDTH >= 4:
|
||||
x2_ptrs = x1_ptrs + 1 * stride_x_token
|
||||
x2 = tl.load(x2_ptrs, mask_dim & (offset_x + 2 > 0))
|
||||
|
||||
if load_init_state & (chunk_offset == 0):
|
||||
# load from conv_states
|
||||
offset_conv_state = state_len - KERNEL_WIDTH + 1
|
||||
if KERNEL_WIDTH >= 2:
|
||||
x0_ptrs = conv_states_base + offset_conv_state * stride_conv_state_tok
|
||||
x0 = tl.load(x0_ptrs, mask_dim, 0.0)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
x1_ptrs = x0_ptrs + 1 * stride_conv_state_tok
|
||||
x1 = tl.load(x1_ptrs, mask_dim)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
x2_ptrs = x1_ptrs + 1 * stride_conv_state_tok
|
||||
x2 = tl.load(x2_ptrs, mask_dim)
|
||||
|
||||
if HAS_BIAS:
|
||||
bias = bias_ptr + idx_feats
|
||||
mask_bias = idx_feats < dim
|
||||
acc_preload = tl.load(bias, mask=mask_bias,
|
||||
other=0.0).to(tl.float32) # [BLOCK_N]
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
|
||||
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
|
||||
|
||||
# PRE-LOAD WEIGHTS
|
||||
mask_dim = idx_feats < dim
|
||||
if KERNEL_WIDTH >= 2:
|
||||
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
|
||||
w0 = tl.load(w_ptrs, mask_dim, other=0.0)
|
||||
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
|
||||
w1 = tl.load(w_ptrs, mask_dim, other=0.0)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
|
||||
w2 = tl.load(w_ptrs, mask_dim, other=0.0)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
|
||||
w3 = tl.load(w_ptrs, mask_dim, other=0.0)
|
||||
|
||||
for idx_token in tl.static_range(BLOCK_M):
|
||||
acc = acc_preload
|
||||
mask_1d = (idx_token
|
||||
< segment_len) & mask_dim # token-index # feature-index
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
x = tl.load(x_ptrs_1d, mask=mask_1d)
|
||||
|
||||
if KERNEL_WIDTH == 2:
|
||||
acc += x0 * w0 + x * w1
|
||||
x0 = x
|
||||
elif KERNEL_WIDTH == 3:
|
||||
acc += x0 * w0 + x1 * w1 + x * w2
|
||||
x0 = x1
|
||||
x1 = x
|
||||
elif KERNEL_WIDTH == 4:
|
||||
acc += x0 * w0 + x1 * w1 + x2 * w2 + x * w3
|
||||
x0 = x1
|
||||
x1 = x2
|
||||
x2 = x
|
||||
|
||||
if SILU_ACTIVATION:
|
||||
acc = acc / (1 + tl.exp(-acc))
|
||||
|
||||
o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
|
||||
) * stride_o_token + (idx_feats * stride_o_dim)
|
||||
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||
|
||||
# update conv_state with new data [only by the Triton program handles chunk_offset=0]
|
||||
if chunk_offset == 0:
|
||||
if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
|
||||
# just read from 'x'
|
||||
# copy 'x' data to conv_state
|
||||
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
|
||||
idx_tokens_last = (seqlen - state_len) + tl.arange(
|
||||
0, NP2_STATELEN) # [BLOCK_M]
|
||||
x_ptrs = x_ptr + (
|
||||
(sequence_start_index + idx_tokens_last) *
|
||||
stride_x_token)[:, None] + (
|
||||
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
|
||||
mask_x = ((idx_tokens_last >= 0)[:, None] &
|
||||
(idx_tokens_last < seqlen)[:, None] &
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
conv_states_ptrs_target = conv_states_base[None, :] + (
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
|
||||
< dim)[None, :]
|
||||
tl.debug_barrier()
|
||||
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||
elif load_init_state:
|
||||
# update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
conv_states_ptrs_source = (
|
||||
conv_states_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
|
||||
|
||||
VAL = state_len - seqlen
|
||||
|
||||
x_ptrs = x_base[None, :] + (
|
||||
(idx_tokens_conv - VAL) *
|
||||
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
|
||||
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
|
||||
(idx_tokens_conv - VAL < seqlen)[:, None] &
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
tl.debug_barrier()
|
||||
new_conv_state = tl.where(
|
||||
mask, conv_state, loaded_x
|
||||
) # BUG in 'tl.where' which requires a barrier before this
|
||||
conv_states_ptrs_target = conv_states_base + (
|
||||
idx_tokens_conv *
|
||||
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
|
||||
< dim)[None, :]
|
||||
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||
else:
|
||||
# update conv_state by shifting left, BUT
|
||||
# set cols prior to 'x' as zeros + cols from 'x'
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
VAL = state_len - seqlen
|
||||
|
||||
x_ptrs = x_base[None, :] + (
|
||||
(idx_tokens_conv - VAL) *
|
||||
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
|
||||
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
|
||||
(idx_tokens_conv - VAL < seqlen)[:, None] &
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
|
||||
|
||||
conv_states_ptrs_target = conv_states_base + (
|
||||
idx_tokens_conv *
|
||||
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
|
||||
< dim)[None, :]
|
||||
tl.debug_barrier()
|
||||
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||
|
||||
|
||||
def causal_conv1d_fn(
|
||||
x: torch.Tensor,
|
||||
def causal_conv1d_fn(x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
bias: Union[torch.Tensor, None],
|
||||
conv_states: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
metadata: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
||||
validate_data=False):
|
||||
"""support varlen + continuous batching when x is 2D tensor
|
||||
x: (dim,cu_seq_len)
|
||||
cu_seq_len = total tokens of all seqs in that batch
|
||||
sequences are concatenated from left to right for varlen
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
conv_states: (...,dim,width - 1) itype
|
||||
updated inplace if provided
|
||||
[it use `cache_indices` to get the index to the cache of conv_state for that sequence
|
||||
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
|
||||
and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x'
|
||||
]
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended by 0.
|
||||
if
|
||||
x = [5, 1, 1, 1] <- continuous batching (batch=4)
|
||||
then
|
||||
query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is
|
||||
the ending index of the last sequence
|
||||
[length(query_start_loc)-1 == batch]
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
@@ -90,46 +301,144 @@ def causal_conv1d_fn(
|
||||
has_initial_state: (batch) bool
|
||||
indicates whether should the kernel take the current state as initial
|
||||
state for the calculations
|
||||
conv_states: (...,dim,width - 1) itype
|
||||
updated inplace if provided
|
||||
activation: either None or "silu" or "swish"
|
||||
[single boolean for each sequence in the batch: True or False]
|
||||
bias: (dim,)
|
||||
activation: either None or "silu" or "swish" or True
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim, seqlen)
|
||||
out: same shape as `x`
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
if isinstance(activation, bool) and activation:
|
||||
activation = "silu"
|
||||
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
seqlens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
seqlens = seqlens.tolist()
|
||||
splits = torch.split(x, seqlens, dim=-1)
|
||||
# Store original dtype to cast back at the end
|
||||
out = torch.empty_strided(x.size(),
|
||||
x.stride(),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
|
||||
for i in range(len(seqlens)):
|
||||
x_s = splits[i]
|
||||
if cache_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
dim, _ = x.shape
|
||||
_, width = weight.shape
|
||||
|
||||
state_len = width - 1
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
padded_batch = query_start_loc.size(0) - 1
|
||||
stride_x_dim = x.stride(0)
|
||||
stride_x_token = x.stride(1)
|
||||
stride_w_dim = weight.stride(0)
|
||||
stride_w_width = weight.stride(1)
|
||||
stride_istate_seq = 0
|
||||
stride_istate_dim = 0
|
||||
stride_istate_token = 0
|
||||
stride_o_dim = out.stride(0)
|
||||
stride_o_token = out.stride(1)
|
||||
|
||||
num_cache_lines = 0
|
||||
if conv_states is not None:
|
||||
# extensions to support vLLM:
|
||||
# 1. conv_states is used to replaced initial_states
|
||||
# 2. conv_states serve as a cache with num cache lines can be larger than batch size
|
||||
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
|
||||
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
|
||||
num_cache_lines = conv_states.size(0)
|
||||
stride_istate_seq = conv_states.stride(0)
|
||||
stride_istate_dim = conv_states.stride(1)
|
||||
stride_istate_token = conv_states.stride(2)
|
||||
|
||||
stride_cache_indices = cache_indices.stride(
|
||||
0) if cache_indices is not None else 0
|
||||
|
||||
if validate_data:
|
||||
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
|
||||
assert x.dim() == 2
|
||||
assert width in [2, 3, 4]
|
||||
assert query_start_loc is not None
|
||||
assert query_start_loc.dim() == 1
|
||||
assert x.stride(0) == 1 or x.stride(1) == 1
|
||||
if bias is not None:
|
||||
assert bias.dim() == 1
|
||||
assert dim == bias.size(0)
|
||||
if conv_states is not None:
|
||||
assert (num_cache_lines == conv_states.shape[0]
|
||||
and dim == conv_states.shape[1]
|
||||
and conv_states.shape[2] >= width - 1)
|
||||
assert stride_istate_dim == 1
|
||||
if cache_indices is not None:
|
||||
assert cache_indices.dim() == 1
|
||||
assert padded_batch == cache_indices.size(0)
|
||||
if has_initial_state is not None:
|
||||
assert has_initial_state.size() == (padded_batch, )
|
||||
assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
|
||||
assert weight.stride(1) == 1
|
||||
assert (dim, width) == weight.shape
|
||||
assert is_channel_last, "Need to run in channel-last layout"
|
||||
|
||||
BLOCK_M = 64
|
||||
seqlens = query_start_loc.diff()
|
||||
seq_blocks = -(-seqlens // BLOCK_M)
|
||||
total_seq_blocks = seq_blocks.sum().item()
|
||||
# tracking which seq-idx the Triton program is handling
|
||||
batch_ptr = torch.repeat_interleave(
|
||||
torch.arange(len(seq_blocks), device=x.device),
|
||||
seq_blocks).to(torch.int32)
|
||||
|
||||
# tracking BLOCK_M-based index in the sequence the Triton program is handling
|
||||
max_blocks = seq_blocks.max().item() if len(seq_blocks) > 0 else 0
|
||||
arange = torch.arange(max_blocks, device=x.device)
|
||||
mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1)
|
||||
token_chunk_offset_ptr = arange.repeat(len(seq_blocks),
|
||||
1)[mask].to(torch.int32)
|
||||
|
||||
BLOCK_N = 256
|
||||
grid = (total_seq_blocks, triton.cdiv(dim, BLOCK_N))
|
||||
|
||||
with torch.npu.device(x.device.index):
|
||||
_causal_conv1d_fwd_kernel[grid](
|
||||
# Pointers to matrices
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
|
||||
initial_states=conv_states[cache_indices[i]]
|
||||
if has_initial_state[i] else None))
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
return out_ref_tensor
|
||||
conv_states,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
query_start_loc,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
dim,
|
||||
state_len,
|
||||
num_cache_lines,
|
||||
# stride
|
||||
stride_x_dim,
|
||||
stride_x_token,
|
||||
stride_w_dim,
|
||||
stride_w_width,
|
||||
stride_istate_seq,
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_cache_indices,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
# others
|
||||
pad_slot_id,
|
||||
# META
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
HAS_INITIAL_STATES=has_initial_state is not None,
|
||||
IS_CONTINUOUS_BATCHING=cache_indices is not None,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# TODO copied from vllm and it needs to be optimized
|
||||
@@ -4,7 +4,7 @@ from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
|
||||
from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn
|
||||
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
|
||||
fused_recurrent_gated_delta_rule_fwd_kernel
|
||||
from vllm_ascend.ops.triton.mamba.casual_conv1d import (
|
||||
from vllm_ascend.ops.triton.mamba.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update_npu)
|
||||
|
||||
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
|
||||
|
||||
Reference in New Issue
Block a user