[BugFix] fix qwen3-next compilation error (#7977)
### What this PR does / why we need it?
fix qwen3-next compilation error
- vLLM version: v0.18.0
- vLLM release 0.18.0:
445dc7196f
---------
Signed-off-by: cvSoldier <610496306@qq.com>
This commit is contained in:
@@ -458,3 +458,61 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
def test_causal_conv1d_update_qwen3_next_shape():
    """Smoke-test causal_conv1d_update with the tensor shapes used by Qwen3-Next.

    Runs the NPU kernel on a Qwen3-Next-sized problem and compares a slice of
    its output against the pure-PyTorch reference implementation.
    """
    device = "npu"
    dtype = torch.bfloat16

    # Tolerance selection mirrors the other tests in this file: fp32 is
    # tightest, generic half precision looser, bf16 loosest of all.
    rtol, atol = (3e-4, 1e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    # Problem dimensions taken from a Qwen3-Next decode step.
    total_tokens = 192
    dim = 4096
    kernel_size = 4
    batch_size = 96
    num_states = 929

    # Inputs for the kernel under test.
    x = torch.randn(total_tokens, dim, dtype=dtype, device=device)
    conv_state = torch.randn(num_states, dim, kernel_size, dtype=dtype, device=device)
    weight = torch.randn(dim, kernel_size, dtype=dtype, device=device)
    bias = None
    conv_state_indices = torch.randint(
        0, num_states, (batch_size,), dtype=torch.int32, device=device
    )
    num_accepted_tokens = torch.ones(total_tokens, dtype=torch.int32, device=device)
    query_start_loc = torch.arange(0, total_tokens + 1, dtype=torch.int32, device=device)

    # Kernel configuration flags.
    activation = "silu"
    max_query_len = 2
    pad_slot_id = -1
    validate_data = False
    block_idx_last_scheduled_token = None
    initial_state_idx = None

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation,
        conv_state_indices,
        num_accepted_tokens,
        query_start_loc,
        max_query_len,
        pad_slot_id,
        block_idx_last_scheduled_token,
        initial_state_idx,
        validate_data,
    )

    # Reference path: gather the conv states actually used, then run the
    # pure-PyTorch implementation on the first `batch_size` tokens.
    x_ref = x.clone()
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    # NOTE(review): x_ref[:batch_size] is 2-D here, so transpose(1, 2) looks
    # out of range — confirm the expected rank of the reference input.
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size].transpose(1, 2), conv_state_ref, weight, bias, activation=activation
    ).transpose(1, 2)

    # Only the first `batch_size` rows of the kernel output are compared.
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)

    # Free NPU memory so later tests start from a clean slate.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
||||
@@ -164,7 +164,6 @@ def extract_last_width(x, start_loc, width):
|
||||
"stride_x_seq",
|
||||
"stride_x_token",
|
||||
"stride_conv_state_seq",
|
||||
"stride_conv_state_tok",
|
||||
"stride_state_indices",
|
||||
"stride_o_seq",
|
||||
"stride_o_token",
|
||||
@@ -195,7 +194,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
|
||||
stride_w_width: tl.constexpr,
|
||||
stride_conv_state_seq,
|
||||
stride_conv_state_dim: tl.constexpr,
|
||||
stride_conv_state_tok,
|
||||
stride_conv_state_tok: tl.constexpr,
|
||||
stride_state_indices,
|
||||
stride_o_seq,
|
||||
stride_o_dim: tl.constexpr,
|
||||
|
||||
Reference in New Issue
Block a user