[BugFix] fix qwen3-next compilation error (#7977)
### What this PR does / why we need it?
Fix a Qwen3-Next compilation error in the NPU tiled causal-conv1d update kernel: the `stride_conv_state_tok` parameter of `_causal_conv1d_update_kernel_npu_tiled` is promoted to `tl.constexpr`, matching the neighboring stride parameters, and is dropped from the accompanying list of stride-name strings. A regression test, `test_causal_conv1d_update_qwen3_next_shape`, is added to exercise the Qwen3-Next decode shape (`total_tokens=192`, `dim=4096`, `kernel_size=4`, `batch_size=96`) against the reference implementation. A short sketch of why the annotation matters is included below.
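For context, here is a minimal, hypothetical Triton sketch (kernel and argument names are illustrative, not taken from this repo). Arguments annotated `tl.constexpr` are specialized at compile time rather than passed at runtime, so a mismatch between a kernel's annotations and how its arguments are declared elsewhere can surface as a compilation error:

```python
import triton
import triton.language as tl


@triton.jit
def _strided_copy_kernel(
    x_ptr,
    o_ptr,
    n_elem,  # runtime scalar: may change per launch without recompiling
    stride: tl.constexpr,  # compile-time constant: baked into the kernel
    BLOCK: tl.constexpr,  # compile-time constant: tile size
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elem
    # Because `stride` is constexpr, the compiler can fold this multiply
    # into the address computation.
    vals = tl.load(x_ptr + offs * stride, mask=mask)
    tl.store(o_ptr + offs * stride, vals, mask=mask)


# Launch sketch (assumes 1-D tensors x, o of length n on the target device):
# grid = (triton.cdiv(n, 128),)
# _strided_copy_kernel[grid](x, o, n, stride=1, BLOCK=128)
```

Each distinct constexpr value compiles a separate kernel variant, so strides that are fixed for a given model configuration, as the conv-state token stride is here, are natural `constexpr` candidates.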
- vLLM version: v0.18.0
- vLLM release 0.18.0: 445dc7196f
---------
Signed-off-by: cvSoldier <610496306@qq.com>
@@ -458,3 +458,61 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
+
+
+def test_causal_conv1d_update_qwen3_next_shape():
+    device = "npu"
+    itype = torch.bfloat16
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-2, 5e-2
+
+    total_tokens = 192
+    dim = 4096
+    kernel_size = 4
+    batch_size = 96
+    num_states = 929
+
+    x = torch.randn(total_tokens, dim, dtype=itype, device=device)
+    conv_state = torch.randn(num_states, dim, kernel_size, dtype=itype, device=device)
+    weight = torch.randn(dim, kernel_size, dtype=itype, device=device)
+    bias = None
+    conv_state_indices = torch.randint(0, num_states, (batch_size,), dtype=torch.int32, device=device)
+    num_accepted_tokens = torch.ones(total_tokens, dtype=torch.int32, device=device)
+    query_start_loc = torch.arange(0, total_tokens + 1, dtype=torch.int32, device=device)
+
+    activation = "silu"
+    max_query_len = 2
+    pad_slot_id = -1
+    validate_data = False
+
+    block_idx_last_scheduled_token = None
+    initial_state_idx = None
+
+    out = causal_conv1d_update(
+        x,
+        conv_state,
+        weight,
+        bias,
+        activation,
+        conv_state_indices,
+        num_accepted_tokens,
+        query_start_loc,
+        max_query_len,
+        pad_slot_id,
+        block_idx_last_scheduled_token,
+        initial_state_idx,
+        validate_data,
+    )
+
+    x_ref = x.clone()
+    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
+    out_ref = causal_conv1d_update_ref(
+        x_ref[:batch_size].transpose(1, 2), conv_state_ref, weight, bias, activation=activation
+    ).transpose(1, 2)
+
+    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
+
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
@@ -164,7 +164,6 @@ def extract_last_width(x, start_loc, width):
     "stride_x_seq",
     "stride_x_token",
     "stride_conv_state_seq",
-    "stride_conv_state_tok",
     "stride_state_indices",
     "stride_o_seq",
     "stride_o_token",
@@ -195,7 +194,7 @@ def _causal_conv1d_update_kernel_npu_tiled(
     stride_w_width: tl.constexpr,
     stride_conv_state_seq,
     stride_conv_state_dim: tl.constexpr,
-    stride_conv_state_tok,
+    stride_conv_state_tok: tl.constexpr,
     stride_state_indices,
     stride_o_seq,
     stride_o_dim: tl.constexpr,