[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)
### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.
The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass, which requires frequent CPU intervention and host/device
round-trips. In prefill-heavy workloads with asynchronous scheduling
enabled, these synchronizations create execution "bubbles" and prefill
stutter. **Note that this does not cause asynchronous scheduling to
fail; rather, it prevents the system from reaching its theoretical
throughput due to these unnecessary stalls.**
To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU (see the sketch at the end of this
description).
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously.
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided.
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
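
For reference, a minimal sketch of the idea. The field names match the bundle the Triton wrappers consume in the diff below; the builder and transfer helpers here are illustrative, not the exact code in this PR:

```python
from dataclasses import dataclass

import torch


@dataclass
class GDNPrebuiltMeta:
    # One bundle per (non-speculative, varlen) batch, built once on the host.
    block_indices_cumsum: torch.Tensor
    chunk_indices_chunk64: torch.Tensor
    chunk_offsets_chunk64: torch.Tensor
    update_chunk_offsets_chunk64: torch.Tensor
    final_chunk_indices_chunk64: torch.Tensor
    chunk_indices_large_block: torch.Tensor


def build_chunk_indices_cpu(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # One (sequence_id, chunk_id) pair per chunk, the same layout the
    # fla-style prepare_chunk_indices helper produces, but computed eagerly
    # on the CPU from cu_seqlens alone.
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    n_chunks = (lens + chunk_size - 1) // chunk_size
    pairs = [
        torch.stack([torch.full((n,), i, dtype=torch.long), torch.arange(n)], dim=1)
        for i, n in enumerate(n_chunks.tolist())
    ]
    return torch.cat(pairs, dim=0)


def stage_async(host_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Pinned staging buffer + non-blocking copy: the H2D transfer overlaps
    # with NPU compute instead of stalling the forward pass.
    return host_tensor.pin_memory().to(device, non_blocking=True)
```

Because everything above depends only on `cu_seqlens`, it can be derived once per scheduling step; the forward pass then launches kernels without ever touching the host.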
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
```diff
@@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.LongTensor | None = None,
+    prebuilt_meta=None,
 ):
     forward_context = get_forward_context()
     num_decodes = 0
@@ -47,10 +48,34 @@ def chunk_gated_delta_rule_fwd(
     if attn_metadata is not None:
         num_decodes = attn_metadata.num_decodes
     chunk_size = 64
-    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
+    block_indices_cumsum = None if prebuilt_meta is None else prebuilt_meta.block_indices_cumsum
+    chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_chunk64
+    chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_offsets_chunk64
+    update_chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.update_chunk_offsets_chunk64
+    final_chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.final_chunk_indices_chunk64
+    chunk_indices_large_block = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_large_block
+    g = chunk_local_cumsum(
+        g,
+        chunk_size=chunk_size,
+        cu_seqlens=cu_seqlens,
+        block_indices=block_indices_cumsum,
+    )
     # obtain WY representation. u is actually the new v.
-    A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
-    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
+    A = chunk_scaled_dot_kkt_fwd(
+        k=k,
+        beta=beta,
+        g_cumsum=g,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
+        output_dtype=torch.float32,
+    )
+    A = solve_tril(
+        A=A,
+        cu_seqlens=cu_seqlens,
+        chunk_indices_large_block=chunk_indices_large_block,
+        chunk_indices_bt=chunk_indices_chunk64,
+        output_dtype=k.dtype,
+    )
     w, u = recompute_w_u_fwd(
         k=k,
         v=v,
@@ -58,6 +83,7 @@ def chunk_gated_delta_rule_fwd(
         A=A,
         g_cumsum=g,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
     )
     h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
         k=k,
@@ -67,6 +93,8 @@ def chunk_gated_delta_rule_fwd(
         initial_state=initial_state,
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
+        chunk_offsets=chunk_offsets_chunk64,
     )

     if get_pcp_group().world_size > 1:
@@ -76,10 +104,15 @@
             u=u,
             g=g,
             cu_seqlens=cu_seqlens,
+            chunk_indices=chunk_indices_chunk64,
+            chunk_offsets=chunk_offsets_chunk64,
+            update_chunk_offsets=update_chunk_offsets_chunk64,
             num_decodes=num_decodes,
         )
         all_final_state = get_pcp_group().all_gather(final_state.unsqueeze(0), 0)
-        final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
+        final_chunk_indices = final_chunk_indices_chunk64
+        if final_chunk_indices is None:
+            final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
         final_h_update = h_update[:, final_chunk_indices, :, :, :]
         all_final_h_update = get_pcp_group().all_gather(final_h_update, 0)

@@ -137,6 +170,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
         initial_state: torch.Tensor,
         output_final_state: bool,
         cu_seqlens: torch.LongTensor | None = None,
+        prebuilt_meta=None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
         if use_qk_l2norm_in_kernel:
@@ -152,6 +186,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
+            prebuilt_meta=prebuilt_meta,
         )
         ctx.scale = scale
         ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
@@ -169,6 +204,7 @@ def chunk_gated_delta_rule(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: torch.LongTensor | None = None,
+    prebuilt_meta=None,
     head_first: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
 ):
@@ -268,7 +304,17 @@ def chunk_gated_delta_rule(
     if scale is None:
         scale = k.shape[-1] ** -0.5
     o, final_state = ChunkGatedDeltaRuleFunction.apply(
-        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        prebuilt_meta,
+        use_qk_l2norm_in_kernel,
     )
     if head_first:
         o = rearrange(o, "b t h ... -> b h t ...")
```
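With the wrapper changes above, a caller that has a staged bundle passes it straight through, and the legacy path is untouched when no bundle exists. A hedged usage sketch:

```python
# Hot path: metadata was prebuilt on the CPU and transferred ahead of time
# (attached to the GDN attention metadata via patch_gdn_attn.py), so this
# call launches kernels without a host round-trip.
o, final_state = chunk_gated_delta_rule(
    q, k, v, g, beta,
    cu_seqlens=cu_seqlens,
    prebuilt_meta=prebuilt_meta,
)

# Fallback: no bundle provided; the wrappers prepare chunk indices and
# offsets on the fly, exactly as before this change.
o, final_state = chunk_gated_delta_rule(q, k, v, g, beta, cu_seqlens=cu_seqlens)
```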
```diff
@@ -186,6 +186,8 @@ def chunk_gated_delta_rule_fwd_h(
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     save_new_value: bool = True,
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
     # In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -193,15 +195,18 @@
     H = u.shape[-2]
     BT = chunk_size

-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
+        if chunk_offsets is None:
+            chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
         N, NT, chunk_offsets = (
             len(cu_seqlens) - 1,
             len(chunk_indices),
-            prepare_chunk_offsets(cu_seqlens, BT),
+            chunk_offsets,
         )
     assert K <= 256, "current kernel does not support head dimension larger than 256."

```
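Each wrapper below repeats one guard: use the prebuilt tensor when supplied, otherwise compute it on demand, which is what keeps the legacy path intact. In isolation the pattern is just:

```python
# Recurring fallback guard used by the Triton wrappers: a prebuilt tensor
# short-circuits the on-the-fly preparation; None preserves old behavior.
if cu_seqlens is not None and chunk_indices is None:
    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
```

Note also that `chunk_offsets` is resolved before the `(N, NT, chunk_offsets)` tuple is assembled, so an injected value is used rather than recomputed inside the tuple.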
```diff
@@ -163,6 +163,9 @@ def chunk_gated_delta_rule_fwd_hupdate(
     g: torch.Tensor | None = None,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
+    update_chunk_offsets: torch.Tensor | None = None,
     num_decodes: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
@@ -171,20 +174,25 @@
     H = u.shape[-2]
     BT = chunk_size

-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
+        if chunk_offsets is None:
+            chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
         N, NT, chunk_offsets = (
             len(cu_seqlens) - 1,
             len(chunk_indices),
-            prepare_chunk_offsets(cu_seqlens, BT),
+            chunk_offsets,
         )
     assert K <= 256, "current kernel does not support head dimension larger than 256."

     h_update = k.new_empty(B, NT + N, H, K, K, dtype=torch.float32)
-    update_indices = prepare_update_chunk_offsets(cu_seqlens, BT)[:-1]
+    if cu_seqlens is not None and update_chunk_offsets is None:
+        update_chunk_offsets = prepare_update_chunk_offsets(cu_seqlens, BT)
+    update_indices = update_chunk_offsets[:-1]
     h_update[:, update_indices, :, :, :] = torch.eye(K, dtype=h_update.dtype, device=h_update.device)

     g = g.transpose(1, 2).contiguous()
```
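For orientation on the `hupdate` change: `h_update` allocates `NT + N` slots, one extra slot per sequence seeded with the identity, and `update_chunk_offsets[:-1]` selects those per-sequence slots. One plausible host-side construction, purely illustrative of the layout (the actual `prepare_update_chunk_offsets` lives in the repo and may differ):

```python
import torch

def update_chunk_offsets_cpu(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Illustrative assumption: each sequence occupies (1 + num_chunks) slots
    # in h_update, with the leading slot per sequence holding the identity;
    # cumulative starts then mark those leading slots.
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    n_chunks = (lens + chunk_size - 1) // chunk_size
    return torch.cat([torch.zeros(1, dtype=torch.long), (n_chunks + 1).cumsum(0)])
```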
```diff
@@ -85,6 +85,7 @@ def chunk_scaled_dot_kkt_fwd(
     beta: torch.Tensor,
     g_cumsum: torch.Tensor | None = None,
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -115,13 +116,8 @@

     H = beta.shape[-1]
     BT = chunk_size
-    if cu_seqlens is not None:
-        cu_seqlens = cu_seqlens.cpu()
-        chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
-        chunk_indices = chunk_indices.npu()
-        cu_seqlens = cu_seqlens.npu()
-    else:
-        chunk_indices = None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)

```
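This wrapper is where the host round-trip was most explicit: the old code pulled `cu_seqlens` to the CPU (a blocking device-to-host copy), built the indices, and pushed both back with `.npu()`, on every call. The contrast, sketched:

```python
# Before: a blocking D2H sync plus H2D copies in the hot path.
cu_seqlens_host = cu_seqlens.cpu()                       # stream sync here
chunk_indices = prepare_chunk_indices(cu_seqlens_host, BT).npu()

# After: indices arrive precomputed on-device; the on-the-fly path only
# runs when no prebuilt metadata was attached.
if cu_seqlens is not None and chunk_indices is None:
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
```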
```diff
@@ -81,6 +81,7 @@ def chunk_local_cumsum_scalar(
     reverse: bool = False,
     scale: float = None,
     cu_seqlens: torch.Tensor | None = None,
+    block_indices: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.Tensor | None = torch.float,
 ):
@@ -90,7 +91,8 @@
         B, T, H = g.shape
     assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
     OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
-    block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
+    if cu_seqlens is not None and block_indices is None:
+        block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE)
     num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE)
     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
     grid = (num_blocks, B)
@@ -132,6 +134,7 @@ def chunk_local_cumsum(
         reverse=reverse,
         scale=scale,
         cu_seqlens=cu_seqlens,
+        block_indices=kwargs.get("block_indices"),
         head_first=head_first,
         output_dtype=output_dtype,
     )
```
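For intuition on the cumsum block size: it appears sized so that `H * chunk_size * OPTIM_BLOCK_SIZE` stays near 2**18 elements of work per launch block. With illustrative numbers:

```python
# Illustrative: H = 16 heads, chunk_size = 64.
OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (16 * 64))  # 262144 // 1024 = 256
```

Because this granularity differs from the 64-token chunking used elsewhere, the prebuilt bundle carries `block_indices_cumsum` as a separate field rather than reusing `chunk_indices_chunk64`.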
```diff
@@ -330,6 +330,8 @@ def merge_16x16_to_64x64_inverse_kernel(
 def solve_tril(
     A: torch.Tensor,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices_large_block: torch.Tensor | None = None,
+    chunk_indices_bt: torch.Tensor | None = None,
     output_dtype: torch.dtype = torch.float,
 ) -> torch.Tensor:
     """
@@ -355,7 +357,9 @@

     LARGE_BLOCK_T = 608 * 2

-    chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices_large_block is None:
+        chunk_indices_large_block = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
+    chunk_indices = chunk_indices_large_block
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)

     solve_tril_16x16_kernel[NT, B * H](
@@ -376,7 +380,9 @@

     Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
     merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
-    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices_bt is None:
+        chunk_indices_bt = prepare_chunk_indices(cu_seqlens, BT)
+    chunk_indices = chunk_indices_bt
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

     merge_fn[NT, B * H](
```
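`solve_tril` is why the bundle carries two index tensors at different granularities: the 16x16 solve is tiled at `LARGE_BLOCK_T = 1216` tokens, while the merge kernels tile at `BT` (64 in the GDN path). Both can be prebuilt from the same `cu_seqlens` on the host; using the illustrative helper from the sketch at the top:

```python
# Hypothetical host-side prebuild of both granularities.
chunk_indices_large_block = build_chunk_indices_cpu(cu_seqlens_cpu, 608 * 2)
chunk_indices_bt = build_chunk_indices_cpu(cu_seqlens_cpu, 64)
```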
```diff
@@ -102,12 +102,14 @@ def recompute_w_u_fwd(
     g_cumsum: torch.Tensor,
     A: torch.Tensor,
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, Hg, K, V = *k.shape, v.shape[-1]
     H = v.shape[-2]
     BT = A.shape[-1]

-    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

     BK = 64
```