[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)

### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.

The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass. This approach triggers frequent CPU intervention and host/device
round-trips. When running prefill-heavy workloads with asynchronous
scheduling enabled, these synchronizations create execution "bubbles"
that stall prefill (visible as stuttering). **Note that this does not
cause asynchronous scheduling to fail; rather, it prevents the system
from reaching its theoretical throughput because of these unnecessary
stalls.**
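
For intuition, the kind of round-trip being described looks like the
following (a minimal, hypothetical illustration, not code from this
repo): deriving kernel launch metadata from a device-resident
`cu_seqlens` blocks the host on every forward call.

```python
import torch

def chunks_per_seq(cu_seqlens_dev: torch.Tensor, chunk_size: int = 64) -> list[int]:
    # The subtraction is a cheap, asynchronous device op...
    seq_lens = cu_seqlens_dev[1:] - cu_seqlens_dev[:-1]
    # ...but .tolist() must wait for the device to catch up, so every
    # prefill step pays a blocking device-to-host synchronization here.
    return [(n + chunk_size - 1) // chunk_size for n in seq_lens.tolist()]
```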

To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU.
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously (see the sketch after this
list).
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided.
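
A minimal sketch of what such a prebuild step can look like. All names
and the exact bundle layout here are hypothetical (the real bundle in
this PR is attached via `patch_gdn_attn.py`), and the `"npu"` device
assumes a `torch_npu` runtime:

```python
import torch

def prebuild_varlen_chunk_metadata(cu_seqlens_cpu: torch.Tensor,
                                   chunk_size: int = 64) -> dict[str, torch.Tensor]:
    # Per-sequence lengths and chunk counts, computed entirely on the host.
    seq_lens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]
    num_chunks = (seq_lens + chunk_size - 1) // chunk_size

    # chunk_indices: one (sequence id, chunk id within the sequence) row per chunk.
    chunk_indices = torch.cat([
        torch.stack([torch.full((int(n),), i, dtype=torch.int32),
                     torch.arange(int(n), dtype=torch.int32)], dim=1)
        for i, n in enumerate(num_chunks)
    ])
    # chunk_offsets: exclusive prefix sum of chunks per sequence.
    chunk_offsets = torch.cat([torch.zeros(1, dtype=seq_lens.dtype),
                               torch.cumsum(num_chunks, dim=0)])

    # Stage in pinned host memory so the host-to-device copy can be issued
    # asynchronously and overlap with compute instead of blocking the stream.
    staged = {"chunk_indices": chunk_indices.pin_memory(),
              "chunk_offsets": chunk_offsets.pin_memory()}
    return {name: t.to("npu", non_blocking=True) for name, t in staged.items()}
```

Because the inputs (`cu_seqlens`) are already known on the host when the
batch is scheduled, nothing in this step needs to read device memory,
which is what removes the synchronization from the hot path.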

- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
Commit: 9bf9b4b267 (parent b2e71b7930)
Author: Qi Mao, 2026-03-22 23:09:23 +08:00 (committed by GitHub)
13 changed files with 824 additions and 21 deletions


@@ -163,6 +163,9 @@ def chunk_gated_delta_rule_fwd_hupdate(
     g: torch.Tensor | None = None,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
+    update_chunk_offsets: torch.Tensor | None = None,
     num_decodes: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
@@ -171,20 +174,25 @@ def chunk_gated_delta_rule_fwd_hupdate(
     H = u.shape[-2]
     BT = chunk_size
-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
+        if chunk_offsets is None:
+            chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
         N, NT, chunk_offsets = (
             len(cu_seqlens) - 1,
             len(chunk_indices),
-            prepare_chunk_offsets(cu_seqlens, BT),
+            chunk_offsets,
         )
     assert K <= 256, "current kernel does not support head dimension larger than 256."
     h_update = k.new_empty(B, NT + N, H, K, K, dtype=torch.float32)
-    update_indices = prepare_update_chunk_offsets(cu_seqlens, BT)[:-1]
+    if cu_seqlens is not None and update_chunk_offsets is None:
+        update_chunk_offsets = prepare_update_chunk_offsets(cu_seqlens, BT)
+    update_indices = update_chunk_offsets[:-1]
     h_update[:, update_indices, :, :, :] = torch.eye(K, dtype=h_update.dtype, device=h_update.device)
     g = g.transpose(1, 2).contiguous()
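
The fallback pattern this hunk applies generalizes to each new argument:
metadata arrives as an optional parameter, and the legacy preparation
helper runs only when nothing was prebuilt. A condensed, hypothetical
restatement (assuming the `prepare_*` helpers shown in the diff above
are importable from the same module):

```python
def resolve_gdn_chunk_metadata(cu_seqlens, chunk_size=64, chunk_indices=None,
                               chunk_offsets=None, update_chunk_offsets=None):
    # Prebuilt metadata (the fast path) passes straight through; any
    # missing piece falls back to the legacy on-the-fly preparation.
    if cu_seqlens is not None:
        if chunk_indices is None:
            chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        if chunk_offsets is None:
            chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size)
        if update_chunk_offsets is None:
            update_chunk_offsets = prepare_update_chunk_offsets(cu_seqlens, chunk_size)
    return chunk_indices, chunk_offsets, update_chunk_offsets
```

Keeping the legacy path intact means callers that never see the prebuilt
bundle (or tests invoking the kernels directly) behave exactly as before.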