[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)
### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.
The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass. This approach triggers frequent CPU intervention and host/device
round-trips. When running prefill-heavy workloads with asynchronous
scheduling enabled, these synchronizations result in execution "bubbles"
and prefill stalling (stuttering). **Note that this does not cause
asynchronous scheduling to fail; rather, it prevents the system from
reaching its theoretical throughput due to these unnecessary stalls.**
To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU.
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously (see the sketch after this
list).
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided.
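A minimal sketch of the prebuild-and-stage pattern the first two bullets
describe, not the PR's actual code: `prebuild_chunk_metadata`, the chunk
size of 64, and the `"cuda"` target are illustrative assumptions (on
Ascend the copy would target the NPU via torch_npu).

```python
import torch

def prebuild_chunk_metadata(cu_seqlens: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    """Build (sequence, chunk) index pairs on the CPU at schedule time,
    so the forward pass never blocks on host-side preparation."""
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # host work, off the hot path
    pairs = [
        (seq_id, chunk_id)
        for seq_id, seqlen in enumerate(seqlens)
        for chunk_id in range((seqlen + chunk_size - 1) // chunk_size)  # ceil division
    ]
    # Stage in pinned memory so the host-to-device copy can be asynchronous.
    return torch.tensor(pairs, dtype=torch.int32).pin_memory()

cu_seqlens = torch.tensor([0, 100, 256], dtype=torch.int32)  # two varlen sequences
staging = prebuild_chunk_metadata(cu_seqlens)
# Non-blocking copy out of pinned memory; the forward pass later consumes
# `block_indices` without any host/device round-trip.
block_indices = staging.to("cuda", non_blocking=True)
```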
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
```diff
@@ -81,6 +81,7 @@ def chunk_local_cumsum_scalar(
     reverse: bool = False,
     scale: float = None,
     cu_seqlens: torch.Tensor | None = None,
+    block_indices: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.Tensor | None = torch.float,
 ):
@@ -90,7 +91,8 @@ def chunk_local_cumsum_scalar(
     B, T, H = g.shape
     assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
     OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
-    block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
+    if cu_seqlens is not None and block_indices is None:
+        block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE)
     num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE)
     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
     grid = (num_blocks, B)
@@ -132,6 +134,7 @@ def chunk_local_cumsum(
         reverse=reverse,
         scale=scale,
         cu_seqlens=cu_seqlens,
+        block_indices=kwargs.get("block_indices"),
         head_first=head_first,
         output_dtype=output_dtype,
     )
```
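The hunks above add an optional `block_indices` parameter and guard the
legacy `prepare_chunk_indices` call so it only runs when nothing was
prebuilt. A self-contained sketch of that fallback idiom, with
`prepare_indices` as a hypothetical stand-in for the real
`prepare_chunk_indices`:

```python
import torch

def prepare_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Stand-in: (sequence, chunk) index pairs for a varlen batch.
    pairs = [
        (seq_id, chunk_id)
        for seq_id, seqlen in enumerate((cu_seqlens[1:] - cu_seqlens[:-1]).tolist())
        for chunk_id in range((seqlen + chunk_size - 1) // chunk_size)
    ]
    return torch.tensor(pairs, dtype=torch.int32)

def wrapper(cu_seqlens=None, block_indices=None, chunk_size=64):
    # Mirrors the new guard: prebuilt metadata wins; otherwise fall back
    # to the legacy in-forward preparation.
    if cu_seqlens is not None and block_indices is None:
        block_indices = prepare_indices(cu_seqlens, chunk_size)
    return block_indices

cu = torch.tensor([0, 100, 256])
prebuilt = prepare_indices(cu, 64)                       # built ahead of time
assert torch.equal(wrapper(cu, block_indices=prebuilt),  # fast path
                   wrapper(cu))                          # legacy fallback
```

Because the fallback reproduces the legacy behavior, callers that never
pass `block_indices` are unaffected, which is the backward-compatibility
guarantee the description promises.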