[Ops] update causal_conv1d_update (#5984)
### What this PR does / why we need it?
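Update the `causal_conv1d_update` op for better performance: `weight` and `conv_state` are now transposed into contiguous layouts before the kernel is launched.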
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -595,6 +595,8 @@ def causal_conv1d_update_npu(
|
|||||||
indices 0 and 3
|
indices 0 and 3
|
||||||
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
||||||
"""
|
"""
|
||||||
|
weight = weight.transpose(0, 1).contiguous()
|
||||||
|
conv_state = conv_state.transpose(1, 2).contiguous()
|
||||||
if validate_data:
|
if validate_data:
|
||||||
assert pad_slot_id is not None
|
assert pad_slot_id is not None
|
||||||
assert x.stride(1) == 1
|
assert x.stride(1) == 1
|
||||||
@@ -608,40 +610,33 @@ def causal_conv1d_update_npu(
|
|||||||
unsqueeze = query_start_loc is None and x.dim() == 2
|
unsqueeze = query_start_loc is None and x.dim() == 2
|
||||||
if unsqueeze:
|
if unsqueeze:
|
||||||
# make it (batch, dim, seqlen) with seqlen == 1
|
# make it (batch, dim, seqlen) with seqlen == 1
|
||||||
x = x.unsqueeze(-1)
|
x = x.unsqueeze(1)
|
||||||
|
|
||||||
if query_start_loc is None:
|
if query_start_loc is None:
|
||||||
batch, dim, seqlen = x.shape
|
batch, seqlen, dim = x.shape
|
||||||
else:
|
else:
|
||||||
assert conv_state_indices is not None
|
assert conv_state_indices is not None
|
||||||
batch = conv_state_indices.size(0)
|
batch = conv_state_indices.size(0)
|
||||||
dim = x.size(1)
|
dim = x.size(1)
|
||||||
seqlen = max_query_len
|
seqlen = max_query_len
|
||||||
|
|
||||||
_, width = weight.shape
|
width, _ = weight.shape
|
||||||
num_cache_lines, _, state_len_total = conv_state.size()
|
num_cache_lines, state_len_total,_ = conv_state.size()
|
||||||
|
|
||||||
if validate_data:
|
|
||||||
assert dim == weight.size(0)
|
|
||||||
assert conv_state.stride(-2) == 1
|
|
||||||
assert state_len_total >= width - 1
|
|
||||||
assert num_cache_lines >= batch
|
|
||||||
assert weight.stride(1) == 1
|
|
||||||
|
|
||||||
# overwrite-on-x strategy same as original
|
# overwrite-on-x strategy same as original
|
||||||
out = x
|
out = x
|
||||||
|
|
||||||
stride_w_dim, stride_w_width = weight.stride()
|
stride_w_width, stride_w_dim = weight.stride()
|
||||||
if query_start_loc is None:
|
if query_start_loc is None:
|
||||||
stride_x_seq, stride_x_dim, stride_x_token = x.stride()
|
stride_x_seq, stride_x_token,stride_x_dim = x.stride()
|
||||||
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
stride_o_seq, stride_o_token, stride_o_dim = out.stride()
|
||||||
else:
|
else:
|
||||||
stride_x_token, stride_x_dim = x.stride()
|
stride_x_token, stride_x_dim = x.stride()
|
||||||
stride_x_seq = 0
|
stride_x_seq = 0
|
||||||
stride_o_token, stride_o_dim = out.stride()
|
stride_o_token, stride_o_dim = out.stride()
|
||||||
stride_o_seq = 0
|
stride_o_seq = 0
|
||||||
|
|
||||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride(
|
||||||
)
|
)
|
||||||
stride_state_indices = conv_state_indices.stride(
|
stride_state_indices = conv_state_indices.stride(
|
||||||
0) if conv_state_indices is not None else 0
|
0) if conv_state_indices is not None else 0
|
||||||
@@ -657,7 +652,7 @@ def causal_conv1d_update_npu(
|
|||||||
#keep program count around ~[80..160]
|
#keep program count around ~[80..160]
|
||||||
# vector core 40
|
# vector core 40
|
||||||
# TODO: use driver to get the vector core num
|
# TODO: use driver to get the vector core num
|
||||||
CORE_HINT = 40
|
CORE_HINT = 40
|
||||||
# channel tile: 512 when dim large (reduce tasks), else 256
|
# channel tile: 512 when dim large (reduce tasks), else 256
|
||||||
block_n = 512 if dim >= 512 else 256
|
block_n = 512 if dim >= 512 else 256
|
||||||
g = triton.cdiv(dim, block_n)
|
g = triton.cdiv(dim, block_n)
|
||||||
@@ -674,14 +669,13 @@ def causal_conv1d_update_npu(
|
|||||||
b_tile = 8
|
b_tile = 8
|
||||||
|
|
||||||
# token chunk based on block_n (32KB UB idea); conservative
|
# token chunk based on block_n (32KB UB idea); conservative
|
||||||
t_chunk = 20 if block_n == 512 else 48
|
t_chunk = 1 if block_n == 512 else 48
|
||||||
|
|
||||||
def grid(META):
|
def grid(META):
|
||||||
return (
|
return (
|
||||||
triton.cdiv(batch, META["B_TILE"]),
|
triton.cdiv(batch, META["B_TILE"]),
|
||||||
triton.cdiv(dim, META["BLOCK_N"]),
|
triton.cdiv(dim, META["BLOCK_N"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
_causal_conv1d_update_kernel_npu_tiled[grid](
|
_causal_conv1d_update_kernel_npu_tiled[grid](
|
||||||
x,
|
x,
|
||||||
weight,
|
weight,
|
||||||
@@ -725,5 +719,5 @@ def causal_conv1d_update_npu(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if unsqueeze:
|
if unsqueeze:
|
||||||
out = out.squeeze(-1)
|
out = out.squeeze(1)
|
||||||
return out.to(original_x_dtype)
|
return out.to(original_x_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user