[Ops] update causal_conv1d_update (#5984)
### What this PR does / why we need it?
Update causal_conv1d_update ops for better perf.
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -595,6 +595,8 @@ def causal_conv1d_update_npu(
|
||||
indices 0 and 3
|
||||
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
||||
"""
|
||||
weight = weight.transpose(0, 1).contiguous()
|
||||
conv_state = conv_state.transpose(1, 2).contiguous()
|
||||
if validate_data:
|
||||
assert pad_slot_id is not None
|
||||
assert x.stride(1) == 1
|
||||
@@ -608,40 +610,33 @@ def causal_conv1d_update_npu(
|
||||
unsqueeze = query_start_loc is None and x.dim() == 2
|
||||
if unsqueeze:
|
||||
# make it (batch, dim, seqlen) with seqlen == 1
|
||||
x = x.unsqueeze(-1)
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
if query_start_loc is None:
|
||||
batch, dim, seqlen = x.shape
|
||||
batch, seqlen, dim = x.shape
|
||||
else:
|
||||
assert conv_state_indices is not None
|
||||
batch = conv_state_indices.size(0)
|
||||
dim = x.size(1)
|
||||
seqlen = max_query_len
|
||||
|
||||
_, width = weight.shape
|
||||
num_cache_lines, _, state_len_total = conv_state.size()
|
||||
|
||||
if validate_data:
|
||||
assert dim == weight.size(0)
|
||||
assert conv_state.stride(-2) == 1
|
||||
assert state_len_total >= width - 1
|
||||
assert num_cache_lines >= batch
|
||||
assert weight.stride(1) == 1
|
||||
width, _ = weight.shape
|
||||
num_cache_lines, state_len_total,_ = conv_state.size()
|
||||
|
||||
# overwrite-on-x strategy same as original
|
||||
out = x
|
||||
|
||||
stride_w_dim, stride_w_width = weight.stride()
|
||||
stride_w_width, stride_w_dim = weight.stride()
|
||||
if query_start_loc is None:
|
||||
stride_x_seq, stride_x_dim, stride_x_token = x.stride()
|
||||
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
||||
stride_x_seq, stride_x_token,stride_x_dim = x.stride()
|
||||
stride_o_seq, stride_o_token, stride_o_dim = out.stride()
|
||||
else:
|
||||
stride_x_token, stride_x_dim = x.stride()
|
||||
stride_x_seq = 0
|
||||
stride_o_token, stride_o_dim = out.stride()
|
||||
stride_o_seq = 0
|
||||
|
||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
||||
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
|
||||
)
|
||||
stride_state_indices = conv_state_indices.stride(
|
||||
0) if conv_state_indices is not None else 0
|
||||
@@ -657,7 +652,7 @@ def causal_conv1d_update_npu(
|
||||
#keep program count around ~[80..160]
|
||||
# vector core 40
|
||||
# TODO: use driver to get the vector core num
|
||||
CORE_HINT = 40
|
||||
CORE_HINT = 40
|
||||
# channel tile: 512 when dim large (reduce tasks), else 256
|
||||
block_n = 512 if dim >= 512 else 256
|
||||
g = triton.cdiv(dim, block_n)
|
||||
@@ -674,14 +669,13 @@ def causal_conv1d_update_npu(
|
||||
b_tile = 8
|
||||
|
||||
# token chunk based on block_n (32KB UB idea); conservative
|
||||
t_chunk = 20 if block_n == 512 else 48
|
||||
t_chunk = 1 if block_n == 512 else 48
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(batch, META["B_TILE"]),
|
||||
triton.cdiv(dim, META["BLOCK_N"]),
|
||||
)
|
||||
|
||||
_causal_conv1d_update_kernel_npu_tiled[grid](
|
||||
x,
|
||||
weight,
|
||||
@@ -725,5 +719,5 @@ def causal_conv1d_update_npu(
|
||||
)
|
||||
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
out = out.squeeze(1)
|
||||
return out.to(original_x_dtype)
|
||||
|
||||
Reference in New Issue
Block a user