### What this PR does / why we need it?
This PR modernizes the type annotations in the Triton op sources (`Optional[X]` → `X | None`, per PEP 604) and reflows the code to a longer line length; no functional change is intended.

**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
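
The same mechanical change is applied in every file above. A minimal before/after sketch of the annotation migration (function names are illustrative, not from the diff):

```python
# Requires Python 3.10+ (or `from __future__ import annotations`) for `X | None`.
from typing import Any, Optional

import torch


# Before: Optional[...] spelling, imported from typing.
def f_old(bias: Optional[torch.Tensor] = None, metadata: Optional[Any] = None) -> None:
    pass


# After: builtin `|` union; only `Any` still needs the typing import.
def f_new(bias: torch.Tensor | None = None, metadata: Any | None = None) -> None:
    pass
```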
### Does this PR introduce _any_ user-facing change?
No. The changes are typing and formatting only.
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Commit diff for `vllm_ascend/ops/triton/mamba/causal_conv1d.py`; the same pattern applies to every file listed above.
```diff
@@ -7,24 +7,23 @@
 # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
 
 # mypy: ignore-errors
 
-from typing import Any, Optional
+from typing import Any
 
 import torch
 import torch.nn.functional as F
 import triton
 import triton.language as tl
 
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID  # type: ignore
 
 
 def causal_conv1d_ref(
     x: torch.Tensor,
     weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    initial_states: Optional[torch.Tensor] = None,
+    bias: torch.Tensor | None = None,
+    initial_states: torch.Tensor | None = None,
     return_final_states: bool = False,
-    final_states_out: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
+    final_states_out: torch.Tensor | None = None,
+    activation: str | None = "silu",
 ):
     """
     x: (batch, dim, seqlen)
@@ -42,19 +41,14 @@ def causal_conv1d_ref(
     dim, width = weight.shape
 
     if initial_states is None:
-        out = F.conv1d(x,
-                       weight.unsqueeze(1),
-                       bias,
-                       padding=width - 1,
-                       groups=dim)
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
     else:
         x = torch.cat([initial_states, x], dim=-1)
         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
     out = out[..., :seqlen]
 
     if return_final_states:
-        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
-            dtype_in)  # (batch, dim, width - 1)
+        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)  # (batch, dim, width - 1)
         if final_states_out is not None:
             final_states_out.copy_(final_states)
         else:
```
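
Since `causal_conv1d_ref` is pure PyTorch, a quick smoke test is a cheap way to confirm the reflow is behavior-preserving. A minimal sketch (shapes are illustrative, and it assumes the function returns only the output tensor when `return_final_states=False`):

```python
import torch

batch, dim, seqlen, width = 2, 64, 16, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)
bias = torch.randn(dim)

out = causal_conv1d_ref(x, weight, bias, activation="silu")
assert out.shape == (batch, dim, seqlen)

# Causality: perturbing the last input token must not change earlier outputs.
x2 = x.clone()
x2[..., -1] += 1.0
out2 = causal_conv1d_ref(x2, weight, bias, activation="silu")
assert torch.allclose(out[..., :-1], out2[..., :-1], atol=1e-5)
```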
```diff
@@ -66,13 +60,13 @@ def causal_conv1d_ref(
 def causal_conv1d_fn(
     x: torch.Tensor,
     weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
-    conv_states: Optional[torch.Tensor] = None,
-    has_initial_state: Optional[torch.Tensor] = None,
-    cache_indices: Optional[torch.Tensor] = None,
-    query_start_loc: Optional[torch.Tensor] = None,
-    metadata: Optional[Any] = None,
+    bias: torch.Tensor | None = None,
+    activation: str | None = "silu",
+    conv_states: torch.Tensor | None = None,
+    has_initial_state: torch.Tensor | None = None,
+    cache_indices: torch.Tensor | None = None,
+    query_start_loc: torch.Tensor | None = None,
+    metadata: Any | None = None,
     pad_slot_id: int = PAD_SLOT_ID,
 ):
     """
@@ -126,10 +120,10 @@ def causal_conv1d_fn(
                 bias,
                 activation=activation,
                 return_final_states=True,
-                final_states_out=conv_states[cache_indices[i]][..., :(
-                    width - 1)].unsqueeze(0),
-                initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
-                if has_initial_state[i] else None))
+                final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0),
+                initial_states=conv_states[cache_indices[i]][..., : (width - 1)] if has_initial_state[i] else None,
+            )
+        )
         out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
     out_ref_tensor = torch.cat(out_ref, dim=0)
     return out_ref_tensor
```
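
`causal_conv1d_fn` and the kernels below share one varlen convention: all sequences are packed along a single token axis and described by cumulative offsets in `query_start_loc` (`batch + 1` entries, so sequence `i` spans `[qs[i], qs[i+1])`). A small sketch of building such offsets (values are illustrative):

```python
import torch

seq_lens = torch.tensor([5, 2, 9])  # three packed sequences
query_start_loc = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])
# tensor([ 0,  5,  7, 16]): sequence 1 occupies packed tokens 5..6
qs, qe = query_start_loc[1].item(), query_start_loc[2].item()
assert qe - qs == seq_lens[1]
```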
```diff
@@ -137,54 +131,50 @@ def causal_conv1d_fn(
 
 @triton.jit
 def _causal_conv1d_update_kernel_npu_tiled(
-    # Pointers
-    x_ptr,  # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
-    w_ptr,  # (dim, width)
-    bias_ptr,
-    conv_state_ptr,  # (num_cache_lines, dim, state_len)
-    conv_state_indices_ptr,
-    num_accepted_tokens_ptr,
-    query_start_loc_ptr,  # (batch + 1)
-    block_idx_last_scheduled_token,  # (batch,)
-    initial_state_idx,  # (batch,)
-    o_ptr,  # same shape as x_ptr
-    batch: tl.int32,
-    dim: tl.constexpr,
-    seqlen: tl.constexpr,  # max seqlen for varlen, or exact seqlen
-    state_len: tl.constexpr,  # effective state_len computed in wrapper
-    num_cache_lines: tl.constexpr,
-
-    # Strides
-    stride_x_seq: tl.constexpr,
-    stride_x_dim: tl.constexpr,
-    stride_x_token: tl.constexpr,
-    stride_w_dim: tl.constexpr,
-    stride_w_width: tl.constexpr,
-    stride_conv_state_seq: tl.constexpr,
-    stride_conv_state_dim: tl.constexpr,
-    stride_conv_state_tok: tl.constexpr,
-    stride_state_indices: tl.constexpr,
-    stride_o_seq: tl.constexpr,
-    stride_o_dim: tl.constexpr,
-    stride_o_token: tl.constexpr,
-
-    # others
-    pad_slot_id: tl.constexpr,
-
-    # Meta
-    HAS_BIAS: tl.constexpr,
-    KERNEL_WIDTH: tl.constexpr,  # <= 6
-    SILU_ACTIVATION: tl.constexpr,
-    IS_VARLEN: tl.constexpr,
-    IS_APC_ENABLED: tl.constexpr,
-    IS_SPEC_DECODING: tl.constexpr,
-    NP2_STATELEN: tl.constexpr,
-    USE_PAD_SLOT: tl.constexpr,
-
-    # tiling
-    BLOCK_N: tl.constexpr,  # channel tile (C_TILE)
-    B_TILE: tl.constexpr,  # batch tile
-    T_CHUNK: tl.constexpr,  # token chunk for state update
+    # Pointers
+    x_ptr,  # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
+    w_ptr,  # (dim, width)
+    bias_ptr,
+    conv_state_ptr,  # (num_cache_lines, dim, state_len)
+    conv_state_indices_ptr,
+    num_accepted_tokens_ptr,
+    query_start_loc_ptr,  # (batch + 1)
+    block_idx_last_scheduled_token,  # (batch,)
+    initial_state_idx,  # (batch,)
+    o_ptr,  # same shape as x_ptr
+    batch: tl.int32,
+    dim: tl.constexpr,
+    seqlen: tl.constexpr,  # max seqlen for varlen, or exact seqlen
+    state_len: tl.constexpr,  # effective state_len computed in wrapper
+    num_cache_lines: tl.constexpr,
+    # Strides
+    stride_x_seq: tl.constexpr,
+    stride_x_dim: tl.constexpr,
+    stride_x_token: tl.constexpr,
+    stride_w_dim: tl.constexpr,
+    stride_w_width: tl.constexpr,
+    stride_conv_state_seq: tl.constexpr,
+    stride_conv_state_dim: tl.constexpr,
+    stride_conv_state_tok: tl.constexpr,
+    stride_state_indices: tl.constexpr,
+    stride_o_seq: tl.constexpr,
+    stride_o_dim: tl.constexpr,
+    stride_o_token: tl.constexpr,
+    # others
+    pad_slot_id: tl.constexpr,
+    # Meta
+    HAS_BIAS: tl.constexpr,
+    KERNEL_WIDTH: tl.constexpr,  # <= 6
+    SILU_ACTIVATION: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    IS_APC_ENABLED: tl.constexpr,
+    IS_SPEC_DECODING: tl.constexpr,
+    NP2_STATELEN: tl.constexpr,
+    USE_PAD_SLOT: tl.constexpr,
+    # tiling
+    BLOCK_N: tl.constexpr,  # channel tile (C_TILE)
+    B_TILE: tl.constexpr,  # batch tile
+    T_CHUNK: tl.constexpr,  # token chunk for state update
 ):
     # program ids
     pid_b = tl.program_id(0)  # batch-tile id
@@ -197,37 +187,30 @@ def _causal_conv1d_update_kernel_npu_tiled(
     # preload weights once per program (shared by B_TILE sequences)
     w_base = w_ptr + idx_feats * stride_w_dim
     # define to avoid "undefined" in branches
-    w_col0 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
-    w_col1 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
-    w_col2 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
-    w_col3 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
-    w_col4 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
-    w_col5 = tl.zeros((BLOCK_N, ), dtype=tl.float32)
+    w_col0 = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    w_col1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    w_col2 = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    w_col3 = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    w_col4 = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    w_col5 = tl.zeros((BLOCK_N,), dtype=tl.float32)
     if KERNEL_WIDTH >= 1:
-        w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w,
-                         other=0.0).to(tl.float32)
+        w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
     if KERNEL_WIDTH >= 2:
-        w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w,
-                         other=0.0).to(tl.float32)
+        w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
     if KERNEL_WIDTH >= 3:
-        w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w,
-                         other=0.0).to(tl.float32)
+        w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
     if KERNEL_WIDTH >= 4:
-        w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w,
-                         other=0.0).to(tl.float32)
+        w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
     if KERNEL_WIDTH >= 5:
-        w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w,
-                         other=0.0).to(tl.float32)
+        w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
     if KERNEL_WIDTH >= 6:
-        w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w,
-                         other=0.0).to(tl.float32)
+        w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32)
 
     # bias vector once per program
     if HAS_BIAS:
-        acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w,
-                           other=0.0).to(tl.float32)
+        acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32)
     else:
-        acc_bias = tl.zeros((BLOCK_N, ), dtype=tl.float32)
+        acc_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
 
     # token index vector for chunked copy
     tok_vec = tl.arange(0, T_CHUNK)  # [T_CHUNK]
@@ -241,36 +224,26 @@ def _causal_conv1d_update_kernel_npu_tiled(
     # APC mapping (optional)
     # -------------------------
     if IS_APC_ENABLED:
-        conv_state_init = tl.load(initial_state_idx + b,
-                                  mask=lane_active,
-                                  other=0).to(tl.int32)
-        current_last_index = tl.load(block_idx_last_scheduled_token + b,
-                                     mask=lane_active,
-                                     other=0).to(tl.int32)
+        conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32)
+        current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32)
     else:
         conv_state_init = tl.full((), 0, tl.int32)
         current_last_index = tl.full((), 0, tl.int32)
 
     # input cache line
-    conv_states_input_coord = tl.load(conv_state_indices_ptr +
-                                      b * stride_state_indices +
-                                      conv_state_init,
-                                      mask=lane_active,
-                                      other=0).to(tl.int64)
+    conv_states_input_coord = tl.load(
+        conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0
+    ).to(tl.int64)
 
     if USE_PAD_SLOT:
-        lane_active = lane_active & (conv_states_input_coord
-                                     != pad_slot_id)
+        lane_active = lane_active & (conv_states_input_coord != pad_slot_id)
 
     # -------------------------
     # varlen (optional): revise seqlen_run and state_len_run like original kernel does
     # -------------------------
     if IS_VARLEN:
-        qs = tl.load(query_start_loc_ptr + b, mask=lane_active,
-                     other=0).to(tl.int64)
-        qe = tl.load(query_start_loc_ptr + (b + 1),
-                     mask=lane_active,
-                     other=0).to(tl.int64)
+        qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64)
+        qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64)
         seqlen_run = (qe - qs).to(tl.int32)
         # revise effective state_len for shorter sequences (same formula as original)
         state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32)
@@ -289,9 +262,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
     # spec decoding offset (optional)
     # -------------------------
     if IS_SPEC_DECODING:
-        conv_state_token_offset = (
-            tl.load(num_accepted_tokens_ptr + b, mask=lane_active,
-                    other=1).to(tl.int64) - 1)
+        conv_state_token_offset = tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1
         shift = tl.full((), 1, tl.int32)  # sliding by 1 in spec mode
     else:
         conv_state_token_offset = tl.full((), 0, tl.int64)
@@ -300,37 +271,37 @@ def _causal_conv1d_update_kernel_npu_tiled(
     # -------------------------
     # STEP 1: read initial history cols BEFORE state update (out==x safe)
     # -------------------------
-    conv_states_base = (conv_state_ptr +
-                        conv_states_input_coord * stride_conv_state_seq +
-                        idx_feats * stride_conv_state_dim)
+    conv_states_base = (
+        conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim
+    )
     prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
 
     # define history vectors as zeros then load conditionally
-    col0 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
-    col1 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
-    col2 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
-    col3 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
-    col4 = tl.zeros((BLOCK_N, ), dtype=tl.float16)
+    col0 = tl.zeros((BLOCK_N,), dtype=tl.float16)
+    col1 = tl.zeros((BLOCK_N,), dtype=tl.float16)
+    col2 = tl.zeros((BLOCK_N,), dtype=tl.float16)
+    col3 = tl.zeros((BLOCK_N,), dtype=tl.float16)
+    col4 = tl.zeros((BLOCK_N,), dtype=tl.float16)
     if KERNEL_WIDTH >= 2:
-        col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok,
-                       mask=lane_active & mask_w,
-                       other=0.0).to(tl.float16)
+        col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
+            tl.float16
+        )
     if KERNEL_WIDTH >= 3:
-        col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok,
-                       mask=lane_active & mask_w,
-                       other=0.0).to(tl.float16)
+        col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
+            tl.float16
+        )
     if KERNEL_WIDTH >= 4:
-        col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok,
-                       mask=lane_active & mask_w,
-                       other=0.0).to(tl.float16)
+        col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
+            tl.float16
+        )
     if KERNEL_WIDTH >= 5:
-        col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok,
-                       mask=lane_active & mask_w,
-                       other=0.0).to(tl.float16)
+        col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
+            tl.float16
+        )
     if KERNEL_WIDTH >= 6:
-        col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok,
-                       mask=lane_active & mask_w,
-                       other=0.0).to(tl.float16)
+        col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to(
+            tl.float16
+        )
 
     # -------------------------
     # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block)
@@ -340,29 +311,25 @@ def _causal_conv1d_update_kernel_npu_tiled(
     # dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run]
     # -------------------------
     # output cache line
-    conv_states_offset = tl.load(conv_state_indices_ptr +
-                                 b * stride_state_indices +
-                                 current_last_index,
-                                 mask=lane_active,
-                                 other=0).to(tl.int64)
+    conv_states_offset = tl.load(
+        conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0
+    ).to(tl.int64)
 
-    use_shift = (seqlen_run < state_len_run)
-    use_tail = (seqlen_run >= state_len_run)
+    use_shift = seqlen_run < state_len_run
+    use_tail = seqlen_run >= state_len_run
 
     zero_i32 = tl.full((), 0, tl.int32)
-    keep_shift = tl.where(use_shift, (state_len_run - seqlen_run),
-                          zero_i32).to(tl.int32)
-    tail_start = tl.where(use_tail, (seqlen_run - state_len_run),
-                          zero_i32).to(tl.int32)
+    keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32)
+    tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32)
 
     # base pointers
-    state_src_base = (conv_state_ptr +
-                      conv_states_input_coord * stride_conv_state_seq +
-                      conv_state_token_offset * stride_conv_state_tok +
-                      idx_feats * stride_conv_state_dim)
-    state_dst_base = (conv_state_ptr +
-                      conv_states_offset * stride_conv_state_seq +
-                      idx_feats * stride_conv_state_dim)
+    state_src_base = (
+        conv_state_ptr
+        + conv_states_input_coord * stride_conv_state_seq
+        + conv_state_token_offset * stride_conv_state_tok
+        + idx_feats * stride_conv_state_dim
+    )
+    state_dst_base = conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim
 
     x_base = x_ptr + x_offset + idx_feats * stride_x_dim
 
@@ -370,16 +337,16 @@ def _causal_conv1d_update_kernel_npu_tiled(
     for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK):
         dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
         src_tok = (dst_tok + shift).to(tl.int32)  # [T_CHUNK]
-        m_tok = use_shift & (dst_tok < keep_shift) & (
-            src_tok < state_len_run) & (dst_tok < state_len_run)
-        m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
-            conv_states_input_coord
-            < num_cache_lines) & (conv_states_offset < num_cache_lines)
+        m_tok = use_shift & (dst_tok < keep_shift) & (src_tok < state_len_run) & (dst_tok < state_len_run)
+        m = (
+            (lane_active & m_tok)[:, None]
+            & mask_w[None, :]
+            & (conv_states_input_coord < num_cache_lines)
+            & (conv_states_offset < num_cache_lines)
+        )
 
-        src_ptrs = state_src_base[
-            None, :] + src_tok[:, None] * stride_conv_state_tok
-        dst_ptrs = state_dst_base[
-            None, :] + dst_tok[:, None] * stride_conv_state_tok
+        src_ptrs = state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok
+        dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
         vals = tl.load(src_ptrs, mask=m, other=0.0)
         tl.store(dst_ptrs, vals, mask=m)
 
@@ -387,14 +354,11 @@ def _causal_conv1d_update_kernel_npu_tiled(
     for t0 in tl.static_range(0, seqlen, T_CHUNK):
         x_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
        dst_tok = (keep_shift + x_tok).to(tl.int32)  # [T_CHUNK]
-        m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok
-                                                    < state_len_run)
-        m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
-            conv_states_offset < num_cache_lines)
+        m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run)
+        m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
 
         x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
-        dst_ptrs = state_dst_base[
-            None, :] + dst_tok[:, None] * stride_conv_state_tok
+        dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
         x_vals = tl.load(x_ptrs, mask=m, other=0.0)
         tl.store(dst_ptrs, x_vals, mask=m)
 
@@ -403,12 +367,10 @@ def _causal_conv1d_update_kernel_npu_tiled(
         dst_tok = (t0 + tok_vec).to(tl.int32)  # [T_CHUNK]
         x_tok = (tail_start + dst_tok).to(tl.int32)  # [T_CHUNK]
         m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run)
-        m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (
-            conv_states_offset < num_cache_lines)
+        m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines)
 
         x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token
-        dst_ptrs = state_dst_base[
-            None, :] + dst_tok[:, None] * stride_conv_state_tok
+        dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok
         x_vals = tl.load(x_ptrs, mask=m, other=0.0)
         tl.store(dst_ptrs, x_vals, mask=m)
 
@@ -433,17 +395,13 @@ def _causal_conv1d_update_kernel_npu_tiled(
         if KERNEL_WIDTH == 1:
             # only x[t] * w0
             x_ptrs_1d = x_base_1d + idx_token * stride_x_token
-            matrix_x = tl.load(x_ptrs_1d,
-                               mask=lane_active & mask_w,
-                               other=0.0).to(tl.float16)
+            matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
             matrix_w = w_col0
         elif KERNEL_WIDTH == 2:
             if j == 1:
                 matrix_w = w_col1
             x_ptrs_1d = x_base_1d + idx_token * stride_x_token
-            matrix_x = tl.load(x_ptrs_1d,
-                               mask=lane_active & mask_w,
-                               other=0.0).to(tl.float16)
+            matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
         elif KERNEL_WIDTH == 3:
             if j == 1:
                 matrix_w = w_col1
@@ -451,9 +409,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
             elif j == 2:
                 matrix_w = w_col2
             x_ptrs_1d = x_base_1d + idx_token * stride_x_token
-            matrix_x = tl.load(x_ptrs_1d,
-                               mask=lane_active & mask_w,
-                               other=0.0).to(tl.float16)
+            matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
         elif KERNEL_WIDTH == 4:
             if j == 1:
                 matrix_w = w_col1
@@ -464,9 +420,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
             elif j == 3:
                 matrix_w = w_col3
             x_ptrs_1d = x_base_1d + idx_token * stride_x_token
-            matrix_x = tl.load(x_ptrs_1d,
-                               mask=lane_active & mask_w,
-                               other=0.0).to(tl.float16)
+            matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
         elif KERNEL_WIDTH == 5:
             if j == 1:
                 matrix_w = w_col1
@@ -480,9 +434,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
             elif j == 4:
                 matrix_w = w_col4
             x_ptrs_1d = x_base_1d + idx_token * stride_x_token
-            matrix_x = tl.load(x_ptrs_1d,
-                               mask=lane_active & mask_w,
-                               other=0.0).to(tl.float16)
+            matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
         elif KERNEL_WIDTH == 6:
             if j == 1:
                 matrix_w = w_col1
@@ -499,9 +451,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
             elif j == 5:
                 matrix_w = w_col5
             x_ptrs_1d = x_base_1d + idx_token * stride_x_token
-            matrix_x = tl.load(x_ptrs_1d,
-                               mask=lane_active & mask_w,
-                               other=0.0).to(tl.float16)
+            matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16)
 
         acc += matrix_x.to(tl.float32) * matrix_w  # [BLOCK_N]
 
```
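
The three `tl.static_range` copy loops above implement an ordinary rolling-window state update, split into a shift branch (`seqlen_run < state_len_run`) and a tail branch. A one-channel reference in plain PyTorch, assuming the non-spec path slides by the number of new tokens (`keep` and `tail_start` mirror the kernel's `keep_shift` and `tail_start`):

```python
import torch


def rolling_state_update(state: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    state_len, seqlen = state.numel(), x.numel()
    if seqlen < state_len:
        # shift branch: dst[0:keep] = src[shift:shift+keep], dst[keep:] = x
        keep = state_len - seqlen
        return torch.cat([state[seqlen:], x])
    # tail branch: only the last state_len new tokens survive
    tail_start = seqlen - state_len
    return x[tail_start:].clone()


state = torch.tensor([0.0, 1.0, 2.0, 3.0])
print(rolling_state_update(state, torch.tensor([10.0, 11.0])))  # tensor([ 2.,  3., 10., 11.])
print(rolling_state_update(state, torch.arange(20.0)))          # tensor([16., 17., 18., 19.])
```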
```diff
@@ -606,7 +556,7 @@ def causal_conv1d_update_npu(
         x = x.unsqueeze(1)
 
     if query_start_loc is None:
-        batch, seqlen, dim = x.shape
+        batch, seqlen, dim = x.shape
     else:
         assert conv_state_indices is not None
         batch = conv_state_indices.size(0)
@@ -614,14 +564,14 @@ def causal_conv1d_update_npu(
         seqlen = max_query_len
 
     width, _ = weight.shape
-    num_cache_lines, state_len_total,_ = conv_state.size()
+    num_cache_lines, state_len_total, _ = conv_state.size()
 
     # overwrite-on-x strategy same as original
     out = x
 
     stride_w_width, stride_w_dim = weight.stride()
     if query_start_loc is None:
-        stride_x_seq, stride_x_token,stride_x_dim = x.stride()
+        stride_x_seq, stride_x_token, stride_x_dim = x.stride()
         stride_o_seq, stride_o_token, stride_o_dim = out.stride()
     else:
         stride_x_token, stride_x_dim = x.stride()
@@ -629,10 +579,8 @@ def causal_conv1d_update_npu(
         stride_o_token, stride_o_dim = out.stride()
         stride_o_seq = 0
 
-    stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
-    )
-    stride_state_indices = conv_state_indices.stride(
-        0) if conv_state_indices is not None else 0
+    stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride()
+    stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0
 
     # effective state_len exactly as original
     if num_accepted_tokens is not None:
@@ -642,10 +590,10 @@ def causal_conv1d_update_npu(
     np2_statelen = triton.next_power_of_2(eff_state_len)
 
     # -------- tiling heuristic--------
-    #keep program count around ~[80..160]
+    # keep program count around ~[80..160]
     # vector core 40
     # TODO: use driver to get the vector core num
-    CORE_HINT = 40
+    CORE_HINT = 40
     # channel tile: 512 when dim large (reduce tasks), else 256
     block_n = 512 if dim >= 512 else 256
     g = triton.cdiv(dim, block_n)
@@ -669,6 +617,7 @@ def causal_conv1d_update_npu(
         triton.cdiv(batch, META["B_TILE"]),
         triton.cdiv(dim, META["BLOCK_N"]),
     )
+
     _causal_conv1d_update_kernel_npu_tiled[grid](
         x,
         weight,
```
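
The launch grid is (batch tiles, channel tiles), so the heuristic's target of roughly 80..160 programs on 40 vector cores constrains `B_TILE` once `BLOCK_N` is fixed. The actual `B_TILE` selection is elided from this diff; the sketch below only illustrates how the visible numbers combine:

```python
import triton


def program_count(batch: int, dim: int, b_tile: int, block_n: int) -> int:
    # mirrors: grid = (triton.cdiv(batch, B_TILE), triton.cdiv(dim, BLOCK_N))
    return triton.cdiv(batch, b_tile) * triton.cdiv(dim, block_n)


batch, dim = 64, 2048
block_n = 512 if dim >= 512 else 256  # wrapper's channel-tile rule -> 4 tiles here
for b_tile in (1, 2, 4, 8):
    print(b_tile, program_count(batch, dim, b_tile, block_n))
# b_tile=2 yields 32 * 4 = 128 programs, inside the ~[80..160] band for CORE_HINT = 40.
```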