[Qwen3-Next] Add Ascend C causal_conv1d_fn (#6661)
### What this PR does / why we need it?
Add an Ascend C implementation of causal_conv1d_fn.
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -92,10 +92,15 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
|
||||
@pytest.mark.parametrize("model_name", MODELS)
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1])
|
||||
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
|
||||
@pytest.mark.skip("Skip this CI.")
|
||||
def test_qwen3_next_mtp_correctness_tp4(model_name: str,
|
||||
num_speculative_tokens: int,
|
||||
disable_padded_drafter_batch: bool):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
|
||||
@@ -8,7 +8,7 @@ from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
|
||||
causal_conv1d_fn)
|
||||
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
|
||||
causal_conv1d_update_npu as causal_conv1d_update
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
|
||||
y_cal = y_cal.to(device)
|
||||
@@ -157,6 +157,90 @@ def causal_conv1d_fn_pytorch(
|
||||
return out_ref_tensor
|
||||
|
||||
|
||||
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [4])
@pytest.mark.parametrize('dim', [2048])
def test_ascend_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                              silu_activation, itype, has_initial_state):
    """Check the `_C_ascend.causal_conv1d_fn` op against the PyTorch reference.

    All sequences in `seq_len` are packed into one input. The custom op and
    `causal_conv1d_fn_pytorch` each get their own conv-state buffer holding
    the same starting values, and both the outputs and the conv states are
    compared afterwards with `validate_cmp`.
    """
    torch.random.manual_seed(0)
    enable_custom_op()

    device = "npu"
    num_seq = len(seq_len)
    total_tokens = sum(seq_len)
    state_len = width - 1 + extra_state_len
    tensor_opts = dict(device=device, dtype=itype)
    state_shape = (num_seq, state_len, dim)

    # NOTE: the order of the randn calls below is kept fixed so the generated
    # values stay the same for the fixed seed.
    x = torch.randn(total_tokens, dim, **tensor_opts).transpose(0, 1)
    weight = torch.randn(dim, width, **tensor_opts)

    # Running sum over [0, len_0, len_1, ...]; cast back to int32 after cumsum.
    lens_with_zero = torch.tensor([0] + seq_len,
                                  device=device,
                                  dtype=torch.int32)
    query_start_loc = torch.cumsum(lens_with_zero,
                                   dim=0).to(dtype=torch.int32)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = "silu" if silu_activation else None

    # Random states when an initial state is requested, zeros otherwise. The
    # reference buffer is separate storage with identical contents.
    make_state = torch.randn if has_initial_state else torch.zeros
    conv_states = make_state(state_shape, **tensor_opts).transpose(-1, -2)
    conv_states_ref = make_state(state_shape, **tensor_opts).transpose(-1, -2)
    if has_initial_state:
        conv_states_ref.copy_(conv_states)

    bias = torch.randn(dim, **tensor_opts) if has_bias else None

    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc)

    # The custom op is handed the last-two-dims-transposed views of x, weight
    # and conv_states, and its result is transposed back before comparison.
    out = torch.ops._C_ascend.causal_conv1d_fn(
        x.transpose(-1, -2),
        weight.transpose(-1, -2),
        bias,
        activation=activation,
        conv_state=conv_states.transpose(-1, -2),
        has_initial_state=has_initial_state_tensor,
        non_spec_state_indices_tensor=cache_indices,
        non_spec_query_start_loc=query_start_loc,
        pad_slot_id=PAD_SLOT_ID,
    ).transpose(-1, -2)

    validate_cmp(out, out_ref, itype)
    # conv_states shares storage with the view passed to the op, so this
    # presumably checks the op's in-place state update -- TODO confirm.
    validate_cmp(conv_states, conv_states_ref, itype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('has_initial_state', [False, True])
|
||||
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
||||
@pytest.mark.parametrize('silu_activation', [True])
|
||||
|
||||
Reference in New Issue
Block a user