[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)

### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.

The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass. This approach triggers frequent host-side work and host/device
round-trips. When running prefill-heavy workloads with asynchronous
scheduling enabled, these synchronizations result in execution "bubbles"
and prefill stalling (stuttering). **Note that this does not cause
asynchronous scheduling to fail; rather, it prevents the system from
reaching its theoretical throughput due to these unnecessary stalls.**

To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU (see the sketch after this list).
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously.
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided.
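
A minimal sketch of the prebuild-and-stage pattern described above, assuming `torch_npu` is installed so tensors can be moved to an `"npu"` device. The helper names, the bundle layout, and the block sizes are illustrative only; `_chunk_indices_cpu` is a simplified stand-in for the `prepare_chunk_indices` helper used by the Triton wrappers, not the PR's actual implementation.

```python
# Hypothetical helper names and chunk sizes (illustrative, not the PR's API).
# Assumes torch_npu is installed so "npu" is a valid device string.
import torch


def _chunk_indices_cpu(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Simplified stand-in for prepare_chunk_indices(): one (seq_idx, chunk_idx)
    row per chunk of every variable-length sequence, computed purely on the CPU."""
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    rows = [(seq, chunk)
            for seq, n in enumerate(lens)
            for chunk in range(-(-n // chunk_size))]  # ceil(n / chunk_size) chunks
    return torch.tensor(rows, dtype=torch.int32)


def prebuild_gdn_chunk_metadata(cu_seqlens_cpu: torch.Tensor,
                                block_sizes: tuple = (64, 608 * 2),
                                device: str = "npu") -> dict:
    """Prebuild varlen chunk metadata on the CPU, stage it in pinned memory,
    and start asynchronous copies to the NPU. The returned bundle can then be
    attached to the GDN attention metadata so the Triton wrappers skip the
    per-call preparation on the hot path."""
    bundle = {}
    for bs in block_sizes:
        idx = _chunk_indices_cpu(cu_seqlens_cpu, bs)   # pure host-side work, no device sync
        staging = idx.pin_memory()                     # pinned staging buffer
        bundle[bs] = staging.to(device, non_blocking=True)  # async H2D transfer
    return bundle
```

Because the staging buffers are pinned, the `non_blocking=True` copies can overlap with other host work, which is what removes the prefill bubbles described above.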

- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
Author: Qi Mao
Committed: 2026-03-22 23:09:23 +08:00 (via GitHub)
Commit: 9bf9b4b267 (parent: b2e71b7930)
13 changed files with 824 additions and 21 deletions


@@ -330,6 +330,8 @@ def merge_16x16_to_64x64_inverse_kernel(
 def solve_tril(
     A: torch.Tensor,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices_large_block: torch.Tensor | None = None,
+    chunk_indices_bt: torch.Tensor | None = None,
     output_dtype: torch.dtype = torch.float,
 ) -> torch.Tensor:
     """
@@ -355,7 +357,9 @@ def solve_tril(
     LARGE_BLOCK_T = 608 * 2
-    chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices_large_block is None:
+        chunk_indices_large_block = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
+    chunk_indices = chunk_indices_large_block
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)
     solve_tril_16x16_kernel[NT, B * H](
@@ -376,7 +380,9 @@ def solve_tril(
     Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
     merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
-    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices_bt is None:
+        chunk_indices_bt = prepare_chunk_indices(cu_seqlens, BT)
+    chunk_indices = chunk_indices_bt
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
     merge_fn[NT, B * H](
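
For orientation, a hedged sketch of how a caller could use the new keyword arguments; `A`, `cu_seqlens`, `BT`, and `bundle` are placeholders standing in for tensors that already exist at the call site, and the variable names are not taken from the PR.

```python
# Placeholder call sites, not code from the PR. `bundle` is assumed to hold the
# chunk indices prebuilt and staged to the NPU ahead of the forward pass.
LARGE_BLOCK_T = 608 * 2

# Hot path: prebuilt metadata is passed in, so solve_tril() does not call
# prepare_chunk_indices() and triggers no extra host/device round-trips.
Ai = solve_tril(
    A,
    cu_seqlens=cu_seqlens,
    chunk_indices_large_block=bundle[LARGE_BLOCK_T],
    chunk_indices_bt=bundle[BT],
    output_dtype=torch.float,
)

# Legacy fallback: with no prebuilt metadata, the wrapper rebuilds the chunk
# indices on the fly, exactly as before this change.
Ai = solve_tril(A, cu_seqlens=cu_seqlens)
```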