[BugFix] fix qwen3-next compilation error (#7977)

### What this PR does / why we need it?
fix qwen3-next compilation error

- vLLM version: v0.18.0
- vLLM release0.18.0:
445dc7196f
---------
Signed-off-by: cvSoldier <610496306@qq.com>
This commit is contained in:
cvSoldier
2026-04-03 20:03:39 +08:00
committed by GitHub
parent 81c6f51a45
commit 6c19270498
2 changed files with 59 additions and 2 deletions

View File

@@ -458,3 +458,61 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
def test_causal_conv1d_update_qwen3_next_shape():
device = "npu"
itype = torch.bfloat16
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
total_tokens = 192
dim = 4096
kernel_size = 4
batch_size = 96
num_states = 929
x = torch.randn(total_tokens, dim, dtype=itype, device=device)
conv_state = torch.randn(num_states, dim, kernel_size, dtype=itype, device=device)
weight = torch.randn(dim, kernel_size, dtype=itype, device=device)
bias = None
conv_state_indices = torch.randint(0, num_states, (batch_size,), dtype=torch.int32, device=device)
num_accepted_tokens = torch.ones(total_tokens, dtype=torch.int32, device=device)
query_start_loc = torch.arange(0, total_tokens + 1, dtype=torch.int32, device=device)
activation = "silu"
max_query_len = 2
pad_slot_id = -1
validate_data = False
block_idx_last_scheduled_token = None
initial_state_idx = None
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation,
conv_state_indices,
num_accepted_tokens,
query_start_loc,
max_query_len,
pad_slot_id,
block_idx_last_scheduled_token,
initial_state_idx,
validate_data,
)
x_ref = x.clone()
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
out_ref = causal_conv1d_update_ref(
x_ref[:batch_size].transpose(1, 2), conv_state_ref, weight, bias, activation=activation
).transpose(1, 2)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -164,7 +164,6 @@ def extract_last_width(x, start_loc, width):
"stride_x_seq",
"stride_x_token",
"stride_conv_state_seq",
"stride_conv_state_tok",
"stride_state_indices",
"stride_o_seq",
"stride_o_token",
@@ -195,7 +194,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
stride_w_width: tl.constexpr,
stride_conv_state_seq,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok,
stride_conv_state_tok: tl.constexpr,
stride_state_indices,
stride_o_seq,
stride_o_dim: tl.constexpr,