[Performance] Qwen3-Next: optimize causal_conv1d_fn triton kernel - up to 9% faster (#10680)
This commit is contained in:
@@ -362,6 +362,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
has_initial_state=has_initial_states,
|
||||
cache_indices=cache_indices,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
).transpose(0, 1)[:seq_len]
|
||||
|
||||
key_split_dim = key_dim // attn_tp_size
|
||||
|
||||
@@ -23,6 +23,7 @@ def causal_conv1d_fn(
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -22,11 +22,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
cache_indices_ptr, # conv_state_indices_ptr
|
||||
has_initial_states_ptr,
|
||||
query_start_loc_ptr,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
||||
# Matrix dimensions
|
||||
batch: tl.int32, # actually padded_batch
|
||||
dim: tl.constexpr,
|
||||
seqlen: tl.int32, # cu_seqlen
|
||||
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||
@@ -69,11 +66,11 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# rather than mixing sequences - to make updating initial_states across sequences efficient
|
||||
|
||||
# single-sequence id
|
||||
idx_seq = tl.load(batch_ptr + tl.program_id(0))
|
||||
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
|
||||
idx_seq = tl.program_id(0)
|
||||
chunk_offset = tl.program_id(1)
|
||||
|
||||
# BLOCK_N elements along the feature-dimension (channel)
|
||||
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
if idx_seq == pad_slot_id:
|
||||
return
|
||||
@@ -86,6 +83,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
token_offset = BLOCK_M * chunk_offset
|
||||
segment_len = min(BLOCK_M, seqlen - token_offset)
|
||||
|
||||
if segment_len <= 0:
|
||||
return
|
||||
|
||||
# base of the sequence
|
||||
x_base = (
|
||||
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
|
||||
@@ -382,12 +382,13 @@ def causal_conv1d_fn(
|
||||
bias: Union[torch.Tensor, None],
|
||||
conv_states: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens_cpu: List[int],
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Support varlen + continuous batching when x is a 2D tensor
|
||||
|
||||
@@ -413,6 +414,8 @@ def causal_conv1d_fn(
|
||||
[length(query_start_loc)-1 == batch]
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
seq_lens_cpu: (batch) int32
|
||||
The sequence lengths of the sequences in the batch
|
||||
cache_indices: (batch) int32
|
||||
indicates the corresponding state index,
|
||||
like so: conv_state = conv_states[cache_indices[batch_id]]
|
||||
@@ -434,26 +437,7 @@ def causal_conv1d_fn(
|
||||
if isinstance(activation, bool) and activation:
|
||||
activation = "silu"
|
||||
|
||||
args = None
|
||||
out = torch.empty_like(x)
|
||||
if metadata is not None:
|
||||
cu_seqlen = metadata.cu_seqlen
|
||||
nums_dict = metadata.nums_dict
|
||||
# x = metadata.x
|
||||
args = nums_dict
|
||||
batch_ptr = metadata.batch_ptr
|
||||
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
|
||||
else:
|
||||
seqlens = np.diff(query_start_loc.to("cpu"))
|
||||
args = seqlens
|
||||
MAX_NUM_PROGRAMS = 1024
|
||||
|
||||
batch_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
||||
) # tracking which seq-idx the Triton program is handling
|
||||
token_chunk_offset_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
||||
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
|
||||
|
||||
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
|
||||
dim, cu_seqlen = x.shape
|
||||
@@ -461,7 +445,6 @@ def causal_conv1d_fn(
|
||||
state_len = width - 1
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
padded_batch = query_start_loc.size(0) - 1
|
||||
stride_x_seq = 0
|
||||
stride_x_dim = x.stride(0)
|
||||
stride_x_token = x.stride(1)
|
||||
@@ -501,6 +484,7 @@ def causal_conv1d_fn(
|
||||
assert query_start_loc is not None
|
||||
assert query_start_loc.dim() == 1
|
||||
assert x.stride(0) == 1 or x.stride(1) == 1
|
||||
padded_batch = query_start_loc.size(0) - 1
|
||||
if bias is not None:
|
||||
assert bias.dim() == 1
|
||||
assert dim == bias.size(0)
|
||||
@@ -516,78 +500,14 @@ def causal_conv1d_fn(
|
||||
assert (dim, width) == weight.shape
|
||||
assert is_channel_last, "Need to run in channel-last layout"
|
||||
|
||||
if metadata is None:
|
||||
|
||||
def num_program(META, seqlens):
|
||||
tot = 0
|
||||
|
||||
mlist = []
|
||||
offsetlist = [] # type: ignore
|
||||
|
||||
nums = -(-seqlens // META["BLOCK_M"])
|
||||
|
||||
tot = nums.sum().item()
|
||||
mlist = np.repeat(np.arange(len(nums)), nums)
|
||||
for idx, num in enumerate(nums):
|
||||
offsetlist.extend(
|
||||
range(num)
|
||||
) # chunk-idx if a sequence is split into multiple chunks
|
||||
|
||||
if META["batch_ptr"].nelement() < len(mlist):
|
||||
newlen = len(mlist) + 1
|
||||
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
|
||||
if META["batch_ptr"].nelement() >= len(mlist):
|
||||
META["batch_ptr"][0 : len(mlist)].copy_(
|
||||
torch.from_numpy(np.array(mlist))
|
||||
)
|
||||
META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
|
||||
torch.from_numpy(np.array(offsetlist))
|
||||
)
|
||||
|
||||
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
|
||||
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
|
||||
META["x_ptr"].device
|
||||
)
|
||||
return tot
|
||||
|
||||
else:
|
||||
|
||||
def num_program(META, nums_dict):
|
||||
tot = nums_dict[META["BLOCK_M"]]["tot"]
|
||||
|
||||
mlist = nums_dict[META["BLOCK_M"]]["mlist"]
|
||||
mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
|
||||
|
||||
offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
|
||||
|
||||
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
|
||||
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
|
||||
META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
|
||||
"token_chunk_offset_ptr"
|
||||
]
|
||||
else:
|
||||
if META["batch_ptr"].nelement() < mlist_len:
|
||||
newlen = mlist_len + 1
|
||||
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
|
||||
if META["batch_ptr"].nelement() >= mlist_len:
|
||||
META["batch_ptr"][0:mlist_len].copy_(mlist)
|
||||
META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
|
||||
return tot
|
||||
|
||||
def grid(META):
|
||||
max_seq_len = max(seq_lens_cpu)
|
||||
return (
|
||||
num_program(META, args),
|
||||
len(seq_lens_cpu), # batch_size
|
||||
(max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
|
||||
triton.cdiv(dim, META["BLOCK_N"]),
|
||||
)
|
||||
|
||||
if batch_ptr.device != x.device:
|
||||
batch_ptr = batch_ptr.to(x.device)
|
||||
token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
|
||||
|
||||
_causal_conv1d_fwd_kernel[grid](
|
||||
# Pointers to matrices
|
||||
x,
|
||||
@@ -597,11 +517,8 @@ def causal_conv1d_fn(
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
query_start_loc,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
padded_batch,
|
||||
dim,
|
||||
cu_seqlen,
|
||||
num_cache_lines,
|
||||
|
||||
Reference in New Issue
Block a user