[Ops] update causal_conv1d_update (#5984)

### What this PR does / why we need it?
Update causal_conv1d_update ops for better perf.

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
LeeWenquan
2026-01-21 16:33:52 +08:00
committed by GitHub
parent 53bfb38192
commit 2a618d2454

View File

@@ -595,6 +595,8 @@ def causal_conv1d_update_npu(
indices 0 and 3 indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
""" """
weight = weight.transpose(0, 1).contiguous()
conv_state = conv_state.transpose(1, 2).contiguous()
if validate_data: if validate_data:
assert pad_slot_id is not None assert pad_slot_id is not None
assert x.stride(1) == 1 assert x.stride(1) == 1
@@ -608,40 +610,33 @@ def causal_conv1d_update_npu(
unsqueeze = query_start_loc is None and x.dim() == 2 unsqueeze = query_start_loc is None and x.dim() == 2
if unsqueeze: if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1 # make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1) x = x.unsqueeze(1)
if query_start_loc is None: if query_start_loc is None:
batch, dim, seqlen = x.shape batch, seqlen, dim = x.shape
else: else:
assert conv_state_indices is not None assert conv_state_indices is not None
batch = conv_state_indices.size(0) batch = conv_state_indices.size(0)
dim = x.size(1) dim = x.size(1)
seqlen = max_query_len seqlen = max_query_len
_, width = weight.shape width, _ = weight.shape
num_cache_lines, _, state_len_total = conv_state.size() num_cache_lines, state_len_total,_ = conv_state.size()
if validate_data:
assert dim == weight.size(0)
assert conv_state.stride(-2) == 1
assert state_len_total >= width - 1
assert num_cache_lines >= batch
assert weight.stride(1) == 1
# overwrite-on-x strategy same as original # overwrite-on-x strategy same as original
out = x out = x
stride_w_dim, stride_w_width = weight.stride() stride_w_width, stride_w_dim = weight.stride()
if query_start_loc is None: if query_start_loc is None:
stride_x_seq, stride_x_dim, stride_x_token = x.stride() stride_x_seq, stride_x_token,stride_x_dim = x.stride()
stride_o_seq, stride_o_dim, stride_o_token = out.stride() stride_o_seq, stride_o_token, stride_o_dim = out.stride()
else: else:
stride_x_token, stride_x_dim = x.stride() stride_x_token, stride_x_dim = x.stride()
stride_x_seq = 0 stride_x_seq = 0
stride_o_token, stride_o_dim = out.stride() stride_o_token, stride_o_dim = out.stride()
stride_o_seq = 0 stride_o_seq = 0
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
) )
stride_state_indices = conv_state_indices.stride( stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0 0) if conv_state_indices is not None else 0
@@ -657,7 +652,7 @@ def causal_conv1d_update_npu(
#keep program count around ~[80..160] #keep program count around ~[80..160]
# vector core 40 # vector core 40
# TODO: use driver to get the vector core num # TODO: use driver to get the vector core num
CORE_HINT = 40 CORE_HINT = 40
# channel tile: 512 when dim large (reduce tasks), else 256 # channel tile: 512 when dim large (reduce tasks), else 256
block_n = 512 if dim >= 512 else 256 block_n = 512 if dim >= 512 else 256
g = triton.cdiv(dim, block_n) g = triton.cdiv(dim, block_n)
@@ -674,14 +669,13 @@ def causal_conv1d_update_npu(
b_tile = 8 b_tile = 8
# token chunk based on block_n (32KB UB idea); conservative # token chunk based on block_n (32KB UB idea); conservative
t_chunk = 20 if block_n == 512 else 48 t_chunk = 1 if block_n == 512 else 48
def grid(META): def grid(META):
return ( return (
triton.cdiv(batch, META["B_TILE"]), triton.cdiv(batch, META["B_TILE"]),
triton.cdiv(dim, META["BLOCK_N"]), triton.cdiv(dim, META["BLOCK_N"]),
) )
_causal_conv1d_update_kernel_npu_tiled[grid]( _causal_conv1d_update_kernel_npu_tiled[grid](
x, x,
weight, weight,
@@ -725,5 +719,5 @@ def causal_conv1d_update_npu(
) )
if unsqueeze: if unsqueeze:
out = out.squeeze(-1) out = out.squeeze(1)
return out.to(original_x_dtype) return out.to(original_x_dtype)