[Ops] update causal_conv1d_update (#5984)

### What this PR does / why we need it?
Update causal_conv1d_update ops for better perf.

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
LeeWenquan
2026-01-21 16:33:52 +08:00
committed by GitHub
parent 53bfb38192
commit 2a618d2454

View File

@@ -595,6 +595,8 @@ def causal_conv1d_update_npu(
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
"""
weight = weight.transpose(0, 1).contiguous()
conv_state = conv_state.transpose(1, 2).contiguous()
if validate_data:
assert pad_slot_id is not None
assert x.stride(1) == 1
@@ -608,40 +610,33 @@ def causal_conv1d_update_npu(
unsqueeze = query_start_loc is None and x.dim() == 2
if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1)
x = x.unsqueeze(1)
if query_start_loc is None:
batch, dim, seqlen = x.shape
batch, seqlen, dim = x.shape
else:
assert conv_state_indices is not None
batch = conv_state_indices.size(0)
dim = x.size(1)
seqlen = max_query_len
_, width = weight.shape
num_cache_lines, _, state_len_total = conv_state.size()
if validate_data:
assert dim == weight.size(0)
assert conv_state.stride(-2) == 1
assert state_len_total >= width - 1
assert num_cache_lines >= batch
assert weight.stride(1) == 1
width, _ = weight.shape
num_cache_lines, state_len_total,_ = conv_state.size()
# overwrite-on-x strategy same as original
out = x
stride_w_dim, stride_w_width = weight.stride()
stride_w_width, stride_w_dim = weight.stride()
if query_start_loc is None:
stride_x_seq, stride_x_dim, stride_x_token = x.stride()
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_x_seq, stride_x_token,stride_x_dim = x.stride()
stride_o_seq, stride_o_token, stride_o_dim = out.stride()
else:
stride_x_token, stride_x_dim = x.stride()
stride_x_seq = 0
stride_o_token, stride_o_dim = out.stride()
stride_o_seq = 0
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
)
stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0
@@ -657,7 +652,7 @@ def causal_conv1d_update_npu(
#keep program count around ~[80..160]
# vector core 40
# TODO: use driver to get the vector core num
CORE_HINT = 40
CORE_HINT = 40
# channel tile: 512 when dim large (reduce tasks), else 256
block_n = 512 if dim >= 512 else 256
g = triton.cdiv(dim, block_n)
@@ -674,14 +669,13 @@ def causal_conv1d_update_npu(
b_tile = 8
# token chunk based on block_n (32KB UB idea); conservative
t_chunk = 20 if block_n == 512 else 48
t_chunk = 1 if block_n == 512 else 48
def grid(META):
return (
triton.cdiv(batch, META["B_TILE"]),
triton.cdiv(dim, META["BLOCK_N"]),
)
_causal_conv1d_update_kernel_npu_tiled[grid](
x,
weight,
@@ -725,5 +719,5 @@ def causal_conv1d_update_npu(
)
if unsqueeze:
out = out.squeeze(-1)
out = out.squeeze(1)
return out.to(original_x_dtype)