[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)

### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.

The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass. This approach triggers frequent CPU intervention and host/device
round-trips. When running prefill-heavy workloads with asynchronous
scheduling enabled, these synchronizations result in execution "bubbles"
and prefill stalling (stuttering). **Note that this does not cause
asynchronous scheduling to fail; rather, it prevents the system from
reaching its theoretical throughput due to these unnecessary stalls.**

To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU.
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously.
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided.

- vLLM version: v0.17.0
- vLLM main:
8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
This commit is contained in:
Qi Mao
2026-03-22 23:09:23 +08:00
committed by GitHub
parent b2e71b7930
commit 9bf9b4b267
13 changed files with 824 additions and 21 deletions

View File

@@ -186,6 +186,8 @@ def chunk_gated_delta_rule_fwd_h(
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: torch.LongTensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -193,15 +195,18 @@ def chunk_gated_delta_rule_fwd_h(
H = u.shape[-2]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
if cu_seqlens is not None and chunk_indices is None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
if chunk_offsets is None:
chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
N, NT, chunk_offsets = (
len(cu_seqlens) - 1,
len(chunk_indices),
prepare_chunk_offsets(cu_seqlens, BT),
chunk_offsets,
)
assert K <= 256, "current kernel does not support head dimension larger than 256."