[FIX] Update _causal_conv1d_update_kernel for Efficient Conv State Handling on NPU (#5322)

Description:

This PR updates the implementation of the Triton operator for deployment
on NPU devices, focusing on optimizing grid size and memory handling
based on NPU limitations.

Design Plan:

Grid Calculation: The grid size is now dynamically calculated by batch
and dim to ensure that the number of programs executed does not exceed
the NPU's vector core capacity. This ensures optimal parallelism without
overloading the hardware.

Data Block Handling: Due to the limited on-chip memory (UB) on Ascend
NPUs, this implementation splits large data into smaller chunks of 32k
or less per block. The kernel performs a for-loop to process the data in
these smaller chunks, minimizing memory usage and avoiding potential
overflows.

Changes Compared to GPU Implementation:

Grid and Block Sizing:

For GPU, the grid and block size were determined based on available
thread counts and memory size. In contrast, the NPU version dynamically
adjusts these parameters using B_TILE and BLOCK_N to optimize for NPU’s
architecture.

Memory Chunking:

The original GPU implementation did not require chunking due to the
higher available memory and processing capacity. For the NPU, data is
divided into smaller chunks (32k or smaller) to comply with memory
constraints on the device. The kernel has been modified to handle this
chunking mechanism inside a loop.

Optimized Thread Usage:

The NPU implementation takes into account the hardware-specific thread
limit (24 threads per vector core), ensuring that the number of active
programs is aligned with the NPU's vector core count, avoiding
over-subscription that would lead to serial processing.

This PR ensures that the operator functions efficiently on Ascend NPU,
considering hardware limitations while maintaining the same
functionality and input parameters as the GPU implementation.


- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
This commit is contained in:
Qi Mao
2025-12-26 09:12:30 +08:00
committed by GitHub
parent 4ce32c1a8d
commit 7372225bcb
2 changed files with 563 additions and 752 deletions

View File

@@ -6,6 +6,8 @@ import torch.nn.functional as F
from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
causal_conv1d_fn)
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
causal_conv1d_update_npu as causal_conv1d_update
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
@@ -156,17 +158,13 @@ def causal_conv1d_fn_pytorch(
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype',
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [False, True])
@pytest.mark.parametrize('has_bias', [False, True])
@pytest.mark.parametrize('seq_len', [[
1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134,
2048, 4096
]])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [2, 3, 4])
@pytest.mark.parametrize('dim', [64, 4160])
@pytest.mark.parametrize('width', [2, 4])
@pytest.mark.parametrize('dim', [4160])
def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
silu_activation, itype, has_initial_state):
@@ -227,4 +225,137 @@ def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
query_start_loc=query_start_loc)
validate_cmp(out, out_ref, itype)
validate_cmp(conv_states, conv_states_ref, itype)
validate_cmp(conv_states, conv_states_ref, itype)
def causal_conv1d_update_ref(x,
conv_state,
weight,
bias=None,
activation=None,
cache_seqlens=None):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = (torch.remainder(width_idx, state_len).unsqueeze(1).expand(
-1, dim, -1))
x_new = torch.cat([conv_state.gather(2, width_idx), x],
dim=-1).to(weight.dtype)
copy_idx = torch.arange(
seqlen, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx,
state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
groups=dim)[:, :, -seqlen:]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3, 64])
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
width, seqlen, has_bias,
silu_activation, itype):
device = "npu"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size
# x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
dtype=itype).transpose(1, 2)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat(
[
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)

File diff suppressed because it is too large Load Diff