[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)

### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.

The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass. This forces frequent CPU intervention and host/device
round-trips. In prefill-heavy workloads with asynchronous scheduling
enabled, these synchronizations show up as execution "bubbles" and
prefill stuttering. **Note that this does not cause asynchronous
scheduling to fail; rather, the unnecessary stalls keep the system from
reaching its theoretical throughput.**

To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU.
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously (see the sketch after this
list).
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided.
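
As a rough illustration of the prebuild-and-stage pattern the bullets describe, here is a minimal sketch. It is illustrative only: the helper name `build_gdn_prebuilt_meta`, the dataclass fields shown (three of the six staged by the actual patch), and the `[seq_idx, chunk_idx_in_seq]` index layout are assumptions, not the PR's code.

```python
# Hypothetical sketch: prebuild varlen chunk metadata on the CPU, stage it
# in pinned host memory, and issue a non-blocking H2D copy so the forward
# pass never stalls on metadata preparation. Names/layout are assumptions.
from dataclasses import dataclass

import torch


@dataclass
class GDNPrebuiltMeta:
    # Device-resident metadata consumed by the Triton wrappers in the diff.
    chunk_indices_chunk64: torch.Tensor        # (num_chunks, 2): [seq_idx, chunk_idx_in_seq]
    chunk_offsets_chunk64: torch.Tensor        # (num_seqs,): global id of each sequence's first chunk
    final_chunk_indices_chunk64: torch.Tensor  # (num_seqs,): global id of each sequence's last chunk


def build_gdn_prebuilt_meta(
    cu_seqlens_cpu: torch.Tensor, device: torch.device, chunk_size: int = 64
) -> GDNPrebuiltMeta:
    seqlens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]
    num_chunks = (seqlens + chunk_size - 1) // chunk_size  # ceil-div, per sequence

    cum = num_chunks.cumsum(0)
    chunk_offsets = torch.cat([torch.zeros(1, dtype=torch.long), cum[:-1]])
    final_chunk_indices = cum - 1

    # Map every global chunk id to (owning sequence, position within it).
    seq_ids = torch.repeat_interleave(torch.arange(len(seqlens)), num_chunks)
    in_seq = torch.arange(int(cum[-1])) - chunk_offsets[seq_ids]
    chunk_indices = torch.stack([seq_ids, in_seq], dim=1)

    def stage(t: torch.Tensor) -> torch.Tensor:
        # Pinned staging buffer -> non-blocking H2D copy (requires an
        # accelerator-enabled PyTorch build).
        pinned = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
        pinned.copy_(t)
        return pinned.to(device, non_blocking=True)

    return GDNPrebuiltMeta(stage(chunk_indices), stage(chunk_offsets), stage(final_chunk_indices))
```

In the actual patch the bundle is built once per batch, attached to the GDN attention metadata via `patch_gdn_attn.py`, and then threaded through the Triton wrappers in the diff below.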

- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
Author: Qi Mao
Date: 2026-03-22 23:09:23 +08:00
Committed by: GitHub
Parent: b2e71b7930
Commit: 9bf9b4b267
13 changed files with 824 additions and 21 deletions


@@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.LongTensor | None = None,
+    prebuilt_meta=None,
 ):
     forward_context = get_forward_context()
     num_decodes = 0
@@ -47,10 +48,34 @@ def chunk_gated_delta_rule_fwd(
     if attn_metadata is not None:
         num_decodes = attn_metadata.num_decodes
     chunk_size = 64
-    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
+    block_indices_cumsum = None if prebuilt_meta is None else prebuilt_meta.block_indices_cumsum
+    chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_chunk64
+    chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_offsets_chunk64
+    update_chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.update_chunk_offsets_chunk64
+    final_chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.final_chunk_indices_chunk64
+    chunk_indices_large_block = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_large_block
+    g = chunk_local_cumsum(
+        g,
+        chunk_size=chunk_size,
+        cu_seqlens=cu_seqlens,
+        block_indices=block_indices_cumsum,
+    )
     # obtain WY representation. u is actually the new v.
-    A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
-    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
+    A = chunk_scaled_dot_kkt_fwd(
+        k=k,
+        beta=beta,
+        g_cumsum=g,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
+        output_dtype=torch.float32,
+    )
+    A = solve_tril(
+        A=A,
+        cu_seqlens=cu_seqlens,
+        chunk_indices_large_block=chunk_indices_large_block,
+        chunk_indices_bt=chunk_indices_chunk64,
+        output_dtype=k.dtype,
+    )
     w, u = recompute_w_u_fwd(
         k=k,
         v=v,
@@ -58,6 +83,7 @@ def chunk_gated_delta_rule_fwd(
         A=A,
         g_cumsum=g,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
     )
     h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
         k=k,
@@ -67,6 +93,8 @@ def chunk_gated_delta_rule_fwd(
         initial_state=initial_state,
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
+        chunk_offsets=chunk_offsets_chunk64,
     )
     if get_pcp_group().world_size > 1:
@@ -76,10 +104,15 @@ def chunk_gated_delta_rule_fwd(
             u=u,
             g=g,
             cu_seqlens=cu_seqlens,
+            chunk_indices=chunk_indices_chunk64,
+            chunk_offsets=chunk_offsets_chunk64,
+            update_chunk_offsets=update_chunk_offsets_chunk64,
             num_decodes=num_decodes,
         )
         all_final_state = get_pcp_group().all_gather(final_state.unsqueeze(0), 0)
-        final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
+        final_chunk_indices = final_chunk_indices_chunk64
+        if final_chunk_indices is None:
+            final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
         final_h_update = h_update[:, final_chunk_indices, :, :, :]
         all_final_h_update = get_pcp_group().all_gather(final_h_update, 0)
@@ -137,6 +170,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
         initial_state: torch.Tensor,
         output_final_state: bool,
         cu_seqlens: torch.LongTensor | None = None,
+        prebuilt_meta=None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
         if use_qk_l2norm_in_kernel:
@@ -152,6 +186,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
+            prebuilt_meta=prebuilt_meta,
         )
         ctx.scale = scale
         ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
@@ -169,6 +204,7 @@ def chunk_gated_delta_rule(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: torch.LongTensor | None = None,
+    prebuilt_meta=None,
     head_first: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
 ):
@@ -268,7 +304,17 @@ def chunk_gated_delta_rule(
     if scale is None:
         scale = k.shape[-1] ** -0.5
     o, final_state = ChunkGatedDeltaRuleFunction.apply(
-        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        prebuilt_meta,
+        use_qk_l2norm_in_kernel,
     )
     if head_first:
         o = rearrange(o, "b t h ... -> b h t ...")
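
Since every new parameter defaults to `None` and each consumer guards with an `is None` fallback, pre-patch call sites run unchanged. A hypothetical before/after at a call site (import path elided; shapes illustrative, using the packed varlen layout implied by `cu_seqlens`; `build_gdn_prebuilt_meta` is the sketch from earlier, not a real API):

```python
import torch

# from <patched module> import chunk_gated_delta_rule  # actual path elided

T, H, K, V = 128, 4, 64, 64                # total tokens, heads, head dims
q = torch.randn(1, T, H, K)
k = torch.randn(1, T, H, K)
v = torch.randn(1, T, H, V)
g = torch.randn(1, T, H)
beta = torch.rand(1, T, H)
cu_seqlens = torch.tensor([0, 48, 128], dtype=torch.long)  # two packed sequences

# Legacy path: wrappers prepare chunk indices on the fly (extra syncs).
o, state = chunk_gated_delta_rule(
    q, k, v, g, beta, cu_seqlens=cu_seqlens, output_final_state=True
)

# Fast path: metadata prebuilt once, outside the hot loop.
meta = build_gdn_prebuilt_meta(cu_seqlens, device=q.device)
o, state = chunk_gated_delta_rule(
    q, k, v, g, beta, cu_seqlens=cu_seqlens, output_final_state=True,
    prebuilt_meta=meta,
)
```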