[qwen3 next] add ascend c causal_conv1d_fn (#6661)

### What this PR does / why we need it?
add ascend c causal_conv1d_fn

- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
ZT-AIA
2026-03-09 23:29:49 +08:00
committed by GitHub
parent 48b624e4cc
commit ee5347e824
26 changed files with 2504 additions and 14 deletions

View File

@@ -92,10 +92,15 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("num_speculative_tokens", [1])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.skip("Skip this CI.")
def test_qwen3_next_mtp_correctness_tp4(model_name: str,
num_speculative_tokens: int,
disable_padded_drafter_batch: bool):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"Hello, my name is",
"The president of the United States is",
"The capital of France is",

View File

@@ -8,7 +8,7 @@ from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
causal_conv1d_fn)
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
causal_conv1d_update_npu as causal_conv1d_update
from vllm_ascend.utils import enable_custom_op
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
y_cal = y_cal.to(device)
@@ -157,6 +157,90 @@ def causal_conv1d_fn_pytorch(
return out_ref_tensor
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [4])
@pytest.mark.parametrize('dim', [2048])
def test_ascend_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                              silu_activation, itype, has_initial_state):
    """Compare the Ascend C ``causal_conv1d_fn`` custom op against the
    PyTorch reference ``causal_conv1d_fn_pytorch`` on a varlen batch.

    Both the convolution output and the (in-place updated) convolution
    states are checked with ``validate_cmp``.
    """
    torch.random.manual_seed(0)
    enable_custom_op()
    device = "npu"
    cu_seqlen, num_seq = sum(seq_len), len(seq_len)
    # State buffer keeps width - 1 trailing inputs, optionally padded
    # with extra_state_len unused slots.
    state_len = width - 1 + extra_state_len
    # Reference impl expects x as (dim, total_tokens).
    x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
    weight = torch.randn(dim, width, device=device, dtype=itype)
    # Prefix-sum of per-sequence lengths -> varlen boundaries.
    query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
                                                device=device,
                                                dtype=torch.int32),
                                   dim=0).to(dtype=torch.int32)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = "silu" if silu_activation else None
    # Build value-equal but independently allocated state buffers for the
    # op under test and the reference (copy_ duplicates the random init).
    if has_initial_state:
        conv_states = torch.randn((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.randn(
            (num_seq, state_len, dim), device=device,
            dtype=itype).transpose(-1, -2).copy_(conv_states)
    else:
        conv_states = torch.zeros((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.zeros((num_seq, state_len, dim),
                                      device=device,
                                      dtype=itype).transpose(-1, -2)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc)
    # The custom op takes the un-transposed layouts, so hand it transposed
    # views of the reference inputs and transpose its output back.
    x_origin = x.transpose(-1, -2)
    weight_origin = weight.transpose(-1, -2)
    conv_states_origin = conv_states.transpose(-1, -2)
    out = torch.ops._C_ascend.causal_conv1d_fn(
        x_origin,
        weight_origin,
        bias,
        activation=activation,
        conv_state=conv_states_origin,
        has_initial_state=has_initial_state_tensor,
        non_spec_state_indices_tensor=cache_indices,
        non_spec_query_start_loc=query_start_loc,
        pad_slot_id=PAD_SLOT_ID,
    ).transpose(-1, -2)
    validate_cmp(out, out_ref, itype)
    # NOTE(review): conv_states is presumably updated in place through the
    # conv_states_origin view by the custom op — confirm against op contract.
    validate_cmp(conv_states, conv_states_ref, itype)
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])