[Qwen3-Next] Add Ascend C causal_conv1d_fn (#6661)

### What this PR does / why we need it?
Add an Ascend C implementation of `causal_conv1d_fn`.

- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
ZT-AIA
2026-03-09 23:29:49 +08:00
committed by GitHub
parent 48b624e4cc
commit ee5347e824
26 changed files with 2504 additions and 14 deletions

View File

@@ -8,7 +8,7 @@ from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
causal_conv1d_fn)
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
causal_conv1d_update_npu as causal_conv1d_update
from vllm_ascend.utils import enable_custom_op
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
y_cal = y_cal.to(device)
@@ -157,6 +157,90 @@ def causal_conv1d_fn_pytorch(
return out_ref_tensor
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [4])
@pytest.mark.parametrize('dim', [2048])
def test_ascend_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                              silu_activation, itype, has_initial_state):
    """Check the ``_C_ascend.causal_conv1d_fn`` custom op against the PyTorch reference.

    Several sequences (lengths given by ``seq_len``) are packed into one
    tensor, run through both ``causal_conv1d_fn_pytorch`` and the custom op,
    and the outputs as well as the convolution state caches are compared
    with ``validate_cmp``.
    """
    torch.random.manual_seed(0)
    enable_custom_op()
    device = "npu"

    num_seq = len(seq_len)
    total_tokens = sum(seq_len)
    state_len = width - 1 + extra_state_len
    state_shape = (num_seq, state_len, dim)

    # NOTE: the order of the random draws below (x, weight, states, bias) is
    # kept fixed so that the seeded inputs stay reproducible.
    x = torch.randn(total_tokens, dim, device=device,
                    dtype=itype).transpose(0, 1)
    weight = torch.randn(dim, width, device=device, dtype=itype)

    # Start offset of every packed sequence, plus the final end offset.
    seq_lengths = torch.tensor([0] + seq_len, device=device, dtype=torch.int32)
    query_start_loc = torch.cumsum(seq_lengths, dim=0).to(dtype=torch.int32)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = "silu" if silu_activation else None

    # Random states when an initial state is requested, all-zero otherwise.
    # Both caches are allocated as (num_seq, state_len, dim) and exposed
    # through a transposed view.
    make_state = torch.randn if has_initial_state else torch.zeros
    conv_states = make_state(state_shape, device=device,
                             dtype=itype).transpose(-1, -2)
    conv_states_ref = make_state(state_shape, device=device,
                                 dtype=itype).transpose(-1, -2)
    if has_initial_state:
        # The reference must start from exactly the same state values.
        conv_states_ref.copy_(conv_states)

    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None

    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc,
    )

    # The custom op is fed the inputs with their last two dims swapped
    # relative to the reference call. ``conv_states_origin`` is a view of
    # ``conv_states``; presumably the op updates the state in place, which is
    # why ``conv_states`` itself is compared below -- TODO confirm.
    x_origin = x.transpose(-1, -2)
    weight_origin = weight.transpose(-1, -2)
    conv_states_origin = conv_states.transpose(-1, -2)
    out = torch.ops._C_ascend.causal_conv1d_fn(
        x_origin,
        weight_origin,
        bias,
        activation=activation,
        conv_state=conv_states_origin,
        has_initial_state=has_initial_state_tensor,
        non_spec_state_indices_tensor=cache_indices,
        non_spec_query_start_loc=query_start_loc,
        pad_slot_id=PAD_SLOT_ID,
    ).transpose(-1, -2)

    validate_cmp(out, out_ref, itype)
    validate_cmp(conv_states, conv_states_ref, itype)
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])