[Performance] Qwen3-Next: optimize causal_conv1d_fn triton kernel - up to 9% faster (#10680)
This commit is contained in:
@@ -362,6 +362,7 @@ class MambaAttnBackend(AttentionBackend):
|
|||||||
has_initial_state=has_initial_states,
|
has_initial_state=has_initial_states,
|
||||||
cache_indices=cache_indices,
|
cache_indices=cache_indices,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
|
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||||
).transpose(0, 1)[:seq_len]
|
).transpose(0, 1)[:seq_len]
|
||||||
|
|
||||||
key_split_dim = key_dim // attn_tp_size
|
key_split_dim = key_dim // attn_tp_size
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ def causal_conv1d_fn(
|
|||||||
conv_states: Optional[torch.Tensor] = None,
|
conv_states: Optional[torch.Tensor] = None,
|
||||||
activation: Optional[str] = "silu",
|
activation: Optional[str] = "silu",
|
||||||
pad_slot_id: int = PAD_SLOT_ID,
|
pad_slot_id: int = PAD_SLOT_ID,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||||
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -22,11 +22,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
cache_indices_ptr, # conv_state_indices_ptr
|
cache_indices_ptr, # conv_state_indices_ptr
|
||||||
has_initial_states_ptr,
|
has_initial_states_ptr,
|
||||||
query_start_loc_ptr,
|
query_start_loc_ptr,
|
||||||
batch_ptr,
|
|
||||||
token_chunk_offset_ptr,
|
|
||||||
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
batch: tl.int32, # actually padded_batch
|
|
||||||
dim: tl.constexpr,
|
dim: tl.constexpr,
|
||||||
seqlen: tl.int32, # cu_seqlen
|
seqlen: tl.int32, # cu_seqlen
|
||||||
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||||
@@ -69,11 +66,11 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
# rather than mixing sequences - to make updating initial_states across sequences efficiently
|
# rather than mixing sequences - to make updating initial_states across sequences efficiently
|
||||||
|
|
||||||
# single-sequence id
|
# single-sequence id
|
||||||
idx_seq = tl.load(batch_ptr + tl.program_id(0))
|
idx_seq = tl.program_id(0)
|
||||||
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
|
chunk_offset = tl.program_id(1)
|
||||||
|
|
||||||
# BLOCK_N elements along the feature-dimension (channel)
|
# BLOCK_N elements along the feature-dimension (channel)
|
||||||
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
if idx_seq == pad_slot_id:
|
if idx_seq == pad_slot_id:
|
||||||
return
|
return
|
||||||
@@ -86,6 +83,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
token_offset = BLOCK_M * chunk_offset
|
token_offset = BLOCK_M * chunk_offset
|
||||||
segment_len = min(BLOCK_M, seqlen - token_offset)
|
segment_len = min(BLOCK_M, seqlen - token_offset)
|
||||||
|
|
||||||
|
if segment_len <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
# base of the sequence
|
# base of the sequence
|
||||||
x_base = (
|
x_base = (
|
||||||
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
|
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
|
||||||
@@ -382,12 +382,13 @@ def causal_conv1d_fn(
|
|||||||
bias: Union[torch.Tensor, None],
|
bias: Union[torch.Tensor, None],
|
||||||
conv_states: torch.Tensor,
|
conv_states: torch.Tensor,
|
||||||
query_start_loc: torch.Tensor,
|
query_start_loc: torch.Tensor,
|
||||||
|
seq_lens_cpu: List[int],
|
||||||
cache_indices: Optional[torch.Tensor] = None,
|
cache_indices: Optional[torch.Tensor] = None,
|
||||||
has_initial_state: Optional[torch.Tensor] = None,
|
has_initial_state: Optional[torch.Tensor] = None,
|
||||||
activation: Optional[str] = "silu",
|
activation: Optional[str] = "silu",
|
||||||
pad_slot_id: int = PAD_SLOT_ID,
|
pad_slot_id: int = PAD_SLOT_ID,
|
||||||
metadata=None,
|
|
||||||
validate_data=False,
|
validate_data=False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""support varlen + continuous batching when x is 2D tensor
|
"""support varlen + continuous batching when x is 2D tensor
|
||||||
|
|
||||||
@@ -413,6 +414,8 @@ def causal_conv1d_fn(
|
|||||||
[length(query_start_loc)-1 == batch]
|
[length(query_start_loc)-1 == batch]
|
||||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||||
x.shape=(dim,17)
|
x.shape=(dim,17)
|
||||||
|
seq_lens_cpu: List[int] of length batch (host-side Python ints, not a tensor)
|
||||||
|
The length of each sequence in the batch; read on the host to size the launch grid
|
||||||
cache_indices: (batch) int32
|
cache_indices: (batch) int32
|
||||||
indicates the corresponding state index,
|
indicates the corresponding state index,
|
||||||
like so: conv_state = conv_states[cache_indices[batch_id]]
|
like so: conv_state = conv_states[cache_indices[batch_id]]
|
||||||
@@ -434,26 +437,7 @@ def causal_conv1d_fn(
|
|||||||
if isinstance(activation, bool) and activation:
|
if isinstance(activation, bool) and activation:
|
||||||
activation = "silu"
|
activation = "silu"
|
||||||
|
|
||||||
args = None
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
if metadata is not None:
|
|
||||||
cu_seqlen = metadata.cu_seqlen
|
|
||||||
nums_dict = metadata.nums_dict
|
|
||||||
# x = metadata.x
|
|
||||||
args = nums_dict
|
|
||||||
batch_ptr = metadata.batch_ptr
|
|
||||||
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
|
|
||||||
else:
|
|
||||||
seqlens = np.diff(query_start_loc.to("cpu"))
|
|
||||||
args = seqlens
|
|
||||||
MAX_NUM_PROGRAMS = 1024
|
|
||||||
|
|
||||||
batch_ptr = torch.full(
|
|
||||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
|
||||||
) # tracking which seq-idx the Triton program is handling
|
|
||||||
token_chunk_offset_ptr = torch.full(
|
|
||||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
|
||||||
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
|
|
||||||
|
|
||||||
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
|
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
|
||||||
dim, cu_seqlen = x.shape
|
dim, cu_seqlen = x.shape
|
||||||
@@ -461,7 +445,6 @@ def causal_conv1d_fn(
|
|||||||
state_len = width - 1
|
state_len = width - 1
|
||||||
np2_statelen = triton.next_power_of_2(state_len)
|
np2_statelen = triton.next_power_of_2(state_len)
|
||||||
|
|
||||||
padded_batch = query_start_loc.size(0) - 1
|
|
||||||
stride_x_seq = 0
|
stride_x_seq = 0
|
||||||
stride_x_dim = x.stride(0)
|
stride_x_dim = x.stride(0)
|
||||||
stride_x_token = x.stride(1)
|
stride_x_token = x.stride(1)
|
||||||
@@ -501,6 +484,7 @@ def causal_conv1d_fn(
|
|||||||
assert query_start_loc is not None
|
assert query_start_loc is not None
|
||||||
assert query_start_loc.dim() == 1
|
assert query_start_loc.dim() == 1
|
||||||
assert x.stride(0) == 1 or x.stride(1) == 1
|
assert x.stride(0) == 1 or x.stride(1) == 1
|
||||||
|
padded_batch = query_start_loc.size(0) - 1
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
assert bias.dim() == 1
|
assert bias.dim() == 1
|
||||||
assert dim == bias.size(0)
|
assert dim == bias.size(0)
|
||||||
@@ -516,78 +500,14 @@ def causal_conv1d_fn(
|
|||||||
assert (dim, width) == weight.shape
|
assert (dim, width) == weight.shape
|
||||||
assert is_channel_last, "Need to run in channel-last layout"
|
assert is_channel_last, "Need to run in channel-last layout"
|
||||||
|
|
||||||
if metadata is None:
|
|
||||||
|
|
||||||
def num_program(META, seqlens):
|
|
||||||
tot = 0
|
|
||||||
|
|
||||||
mlist = []
|
|
||||||
offsetlist = [] # type: ignore
|
|
||||||
|
|
||||||
nums = -(-seqlens // META["BLOCK_M"])
|
|
||||||
|
|
||||||
tot = nums.sum().item()
|
|
||||||
mlist = np.repeat(np.arange(len(nums)), nums)
|
|
||||||
for idx, num in enumerate(nums):
|
|
||||||
offsetlist.extend(
|
|
||||||
range(num)
|
|
||||||
) # chunk-idx if a sequence is split into multiple chunks
|
|
||||||
|
|
||||||
if META["batch_ptr"].nelement() < len(mlist):
|
|
||||||
newlen = len(mlist) + 1
|
|
||||||
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
|
||||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
|
||||||
|
|
||||||
if META["batch_ptr"].nelement() >= len(mlist):
|
|
||||||
META["batch_ptr"][0 : len(mlist)].copy_(
|
|
||||||
torch.from_numpy(np.array(mlist))
|
|
||||||
)
|
|
||||||
META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
|
|
||||||
torch.from_numpy(np.array(offsetlist))
|
|
||||||
)
|
|
||||||
|
|
||||||
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
|
|
||||||
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
|
|
||||||
META["x_ptr"].device
|
|
||||||
)
|
|
||||||
return tot
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def num_program(META, nums_dict):
|
|
||||||
tot = nums_dict[META["BLOCK_M"]]["tot"]
|
|
||||||
|
|
||||||
mlist = nums_dict[META["BLOCK_M"]]["mlist"]
|
|
||||||
mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
|
|
||||||
|
|
||||||
offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
|
|
||||||
|
|
||||||
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
|
|
||||||
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
|
|
||||||
META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
|
|
||||||
"token_chunk_offset_ptr"
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
if META["batch_ptr"].nelement() < mlist_len:
|
|
||||||
newlen = mlist_len + 1
|
|
||||||
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
|
||||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
|
||||||
|
|
||||||
if META["batch_ptr"].nelement() >= mlist_len:
|
|
||||||
META["batch_ptr"][0:mlist_len].copy_(mlist)
|
|
||||||
META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
|
|
||||||
return tot
|
|
||||||
|
|
||||||
def grid(META):
|
def grid(META):
|
||||||
|
max_seq_len = max(seq_lens_cpu)
|
||||||
return (
|
return (
|
||||||
num_program(META, args),
|
len(seq_lens_cpu), # batch_size
|
||||||
|
(max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
|
||||||
triton.cdiv(dim, META["BLOCK_N"]),
|
triton.cdiv(dim, META["BLOCK_N"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_ptr.device != x.device:
|
|
||||||
batch_ptr = batch_ptr.to(x.device)
|
|
||||||
token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
|
|
||||||
|
|
||||||
_causal_conv1d_fwd_kernel[grid](
|
_causal_conv1d_fwd_kernel[grid](
|
||||||
# Pointers to matrices
|
# Pointers to matrices
|
||||||
x,
|
x,
|
||||||
@@ -597,11 +517,8 @@ def causal_conv1d_fn(
|
|||||||
cache_indices,
|
cache_indices,
|
||||||
has_initial_state,
|
has_initial_state,
|
||||||
query_start_loc,
|
query_start_loc,
|
||||||
batch_ptr,
|
|
||||||
token_chunk_offset_ptr,
|
|
||||||
out,
|
out,
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
padded_batch,
|
|
||||||
dim,
|
dim,
|
||||||
cu_seqlen,
|
cu_seqlen,
|
||||||
num_cache_lines,
|
num_cache_lines,
|
||||||
|
|||||||
Reference in New Issue
Block a user