diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py
index 88fc9591..655d18b5 100644
--- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py
@@ -458,3 +458,61 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
+
+
+def test_causal_conv1d_update_qwen3_next_shape():
+    device = "npu"
+    itype = torch.bfloat16
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-2, 5e-2
+
+    total_tokens = 192
+    dim = 4096
+    kernel_size = 4
+    batch_size = 96
+    num_states = 929
+
+    x = torch.randn(total_tokens, dim, dtype=itype, device=device)
+    conv_state = torch.randn(num_states, dim, kernel_size, dtype=itype, device=device)
+    weight = torch.randn(dim, kernel_size, dtype=itype, device=device)
+    bias = None
+    conv_state_indices = torch.randint(0, num_states, (batch_size,), dtype=torch.int32, device=device)
+    num_accepted_tokens = torch.ones(total_tokens, dtype=torch.int32, device=device)
+    query_start_loc = torch.arange(0, total_tokens + 1, dtype=torch.int32, device=device)
+
+    activation = "silu"
+    max_query_len = 2
+    pad_slot_id = -1
+    validate_data = False
+
+    block_idx_last_scheduled_token = None
+    initial_state_idx = None
+
+    out = causal_conv1d_update(
+        x,
+        conv_state,
+        weight,
+        bias,
+        activation,
+        conv_state_indices,
+        num_accepted_tokens,
+        query_start_loc,
+        max_query_len,
+        pad_slot_id,
+        block_idx_last_scheduled_token,
+        initial_state_idx,
+        validate_data,
+    )
+
+    x_ref = x.clone()
+    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
+    out_ref = causal_conv1d_update_ref(
+        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
+    )
+
+    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
+
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
index da7f4183..4c9a0c00 100644
--- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py
+++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py
@@ -164,7 +164,6 @@ def extract_last_width(x, start_loc, width):
         "stride_x_seq",
         "stride_x_token",
         "stride_conv_state_seq",
-        "stride_conv_state_tok",
         "stride_state_indices",
         "stride_o_seq",
         "stride_o_token",
@@ -195,7 +194,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
     stride_w_width: tl.constexpr,
     stride_conv_state_seq,
     stride_conv_state_dim: tl.constexpr,
-    stride_conv_state_tok,
+    stride_conv_state_tok: tl.constexpr,
     stride_state_indices,
     stride_o_seq,
     stride_o_dim: tl.constexpr,