[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)

### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.

The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass. This forces frequent CPU involvement and host/device round-trips.
When running prefill-heavy workloads with asynchronous scheduling
enabled, these synchronizations show up as execution "bubbles" and
prefill stuttering. **Note that this does not cause asynchronous
scheduling to fail; rather, it prevents the system from reaching its
theoretical throughput due to these unnecessary stalls.**
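
For intuition, the problematic shape of the legacy path looks roughly like the sketch below. This is schematic only (the function and tensor names are illustrative, not the actual wrapper signature): deriving chunk indices from a device-resident tensor inside every forward pass forces blocking device/host copies.

```python
import torch


def build_chunk_indices_in_forward(cu_seqlens_npu: torch.Tensor,
                                   chunk_size: int) -> torch.Tensor:
    # Schematic legacy pattern: .tolist() on a device tensor blocks the host
    # until the device is idle and copies the data back -- one sync "bubble"
    # per forward pass.
    lens = (cu_seqlens_npu[1:] - cu_seqlens_npu[:-1]).tolist()
    pairs = [
        [seq_idx, chunk_idx]
        for seq_idx, seq_len in enumerate(lens)
        for chunk_idx in range((seq_len + chunk_size - 1) // chunk_size)
    ]
    # Rebuilding the tensor on the host and shipping it back to the device
    # adds another round-trip in the hot path.
    return torch.tensor(pairs, dtype=torch.int32, device=cu_seqlens_npu.device)
```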

To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU.
- **Asynchronous Transfer:** Staging buffers are kept in pinned host memory
and transferred to the NPU asynchronously (see the sketch after this list).
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided.
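
A minimal sketch of the staging pattern these bullets describe, assuming a hypothetical `StagingSlot` that mirrors the pinned-host/device tensor pair (`CpuGpuBuffer`) exercised by the tests below; the real logic lives in `patch_gdn_attn.py`:

```python
import torch


class StagingSlot:
    """Hypothetical pinned-host/device pair mirroring vLLM's CpuGpuBuffer."""

    def __init__(self, max_rows: int, device: torch.device) -> None:
        # Pinned host memory allows a true asynchronous host-to-device copy.
        self.cpu = torch.zeros(max_rows, 2, dtype=torch.int32, pin_memory=True)
        self.gpu = torch.zeros(max_rows, 2, dtype=torch.int32, device=device)

    def upload(self, rows: int) -> torch.Tensor:
        # non_blocking=True enqueues the copy on the current stream instead of
        # stalling the host until it finishes.
        self.gpu[:rows].copy_(self.cpu[:rows], non_blocking=True)
        return self.gpu[:rows]


def prebuild_chunk_indices(slot: StagingSlot,
                           query_start_loc_cpu: torch.Tensor,
                           chunk_size: int) -> torch.Tensor:
    # All index math runs on CPU tensors the scheduler already owns, so the
    # device is never queried synchronously.
    rows = 0
    lens = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]).tolist()
    for seq_idx, seq_len in enumerate(lens):
        for chunk_idx in range(-(-seq_len // chunk_size)):  # ceil division
            slot.cpu[rows, 0] = seq_idx
            slot.cpu[rows, 1] = chunk_idx
            rows += 1
    return slot.upload(rows)
```

The Triton wrappers then read the device-side view directly; when the bundle is absent they fall back to computing the same indices on the fly, which keeps older callers working.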

- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
Authored by Qi Mao on 2026-03-22 23:09:23 +08:00, committed by GitHub
parent b2e71b7930
commit 9bf9b4b267
13 changed files with 824 additions and 21 deletions


@@ -0,0 +1,389 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from types import SimpleNamespace

import pytest
import torch

import vllm_ascend.patch.worker.patch_gdn_attn as patch_gdn_attn
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.kv_cache_interface import MambaSpec


@dataclass
class BatchSpec:
    """Per-request sequence and query lengths describing one test batch."""

    seq_lens: list[int]
    query_lens: list[int]
    name: str = "unnamed"

    @property
    def batch_size(self):
        return len(self.seq_lens)


def create_common_attn_metadata(
    batch_spec: BatchSpec,
    block_size: int,
    device: torch.device,
) -> CommonAttentionMetadata:
    """Construct CommonAttentionMetadata for a synthetic batch on `device`."""
    query_start_loc = torch.zeros(
        batch_spec.batch_size + 1,
        dtype=torch.int32,
        device=device,
    )
    query_start_loc[1:] = torch.tensor(
        batch_spec.query_lens,
        dtype=torch.int32,
        device=device,
    ).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()

    num_tokens = sum(batch_spec.query_lens)
    seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    # Context length = tokens already computed before this step.
    context_lens = [
        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
        for i in range(batch_spec.batch_size)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
    block_table_tensor = torch.arange(
        batch_spec.batch_size * max_blocks,
        dtype=torch.int32,
        device=device,
    ).view(batch_spec.batch_size, max_blocks)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=batch_spec.batch_size,
        num_actual_tokens=num_tokens,
        max_query_len=max(batch_spec.query_lens),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )


def _make_vllm_config(
    *,
    max_model_len: int = 8192,
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    num_heads: int = 32,
    num_speculative_tokens: int = 0,
):
    """Minimal stand-in for VllmConfig carrying only the fields the builder reads."""
    speculative_config = None
    if num_speculative_tokens > 0:
        speculative_config = SimpleNamespace(
            num_speculative_tokens=num_speculative_tokens,
            parallel_drafting=False,
        )
    model_config = SimpleNamespace(max_model_len=max_model_len)
    model_config.get_num_attention_heads = lambda parallel_config: num_heads
    return SimpleNamespace(
        cache_config=SimpleNamespace(mamba_cache_mode="none"),
        compilation_config=SimpleNamespace(
            cudagraph_mode=CUDAGraphMode.NONE,
            max_cudagraph_capture_size=None,
        ),
        speculative_config=speculative_config,
        scheduler_config=SimpleNamespace(
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
        ),
        parallel_config=SimpleNamespace(
            decode_context_parallel_size=1,
            tensor_parallel_size=1,
        ),
        model_config=model_config,
    )


def _make_builder(*, device: torch.device, num_heads: int, num_speculative_tokens: int):
    vllm_config = _make_vllm_config(
        num_heads=num_heads,
        num_speculative_tokens=num_speculative_tokens,
    )
    spec = MambaSpec(
        block_size=16,
        shapes=((1,), (1,)),
        dtypes=(torch.float32,),
        mamba_cache_mode="none",
    )
    return GDNAttentionMetadataBuilder(spec, ["layer0"], vllm_config, device)


def _next_power_of_2(value: int) -> int:
    if value <= 1:
        return 1
    return 1 << (value - 1).bit_length()


def _prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Legacy reference: one (seq_idx, chunk_idx) row per chunk of each sequence."""
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    pairs: list[list[int]] = []
    for seq_idx, seq_len in enumerate(lens):
        num_chunks = (seq_len + chunk_size - 1) // chunk_size
        for chunk_idx in range(num_chunks):
            pairs.append([seq_idx, chunk_idx])
    if not pairs:
        return torch.empty((0, 2), dtype=cu_seqlens.dtype, device=cu_seqlens.device)
    return torch.tensor(pairs, dtype=cu_seqlens.dtype, device=cu_seqlens.device)


def _prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Legacy reference: cumulative chunk-count offsets (offsets[0] == 0)."""
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    # ceil(len / chunk_size) via floor division on (len + chunk_size - 1).
    num_chunks = torch.div(
        lens + chunk_size - 1,
        chunk_size,
        rounding_mode="floor",
    )
    offsets = torch.zeros(len(num_chunks) + 1, dtype=cu_seqlens.dtype)
    torch.cumsum(num_chunks, dim=0, out=offsets[1:])
    return offsets.to(cu_seqlens.device)


def _prepare_update_chunk_offsets(
    cu_seqlens: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """Legacy reference: like _prepare_chunk_offsets, with one extra chunk per sequence."""
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    num_chunks = torch.div(
        lens + chunk_size - 1,
        chunk_size,
        rounding_mode="floor",
    ) + 1
    offsets = torch.zeros(len(num_chunks) + 1, dtype=cu_seqlens.dtype)
    torch.cumsum(num_chunks, dim=0, out=offsets[1:])
    return offsets.to(cu_seqlens.device)


def _prepare_final_chunk_indices(
    cu_seqlens: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """Legacy reference: each sequence's last chunk index under the extra-chunk numbering."""
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    num_chunks = torch.div(
        lens + chunk_size - 1,
        chunk_size,
        rounding_mode="floor",
    ) + 1
    return (torch.cumsum(num_chunks, dim=0) - 1).to(cu_seqlens.device)


def _build_non_spec_query_start_loc_cpu(
    batch_spec: BatchSpec, spec_mask_cpu: torch.Tensor | None
) -> torch.Tensor:
    """Cumulative query lengths restricted to the non-speculative requests."""
    query_lens = torch.tensor(batch_spec.query_lens, dtype=torch.int32)
    if spec_mask_cpu is not None:
        query_lens = query_lens[~spec_mask_cpu]
    query_start_loc = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
    torch.cumsum(query_lens, dim=0, out=query_start_loc[1:])
    return query_start_loc


@pytest.mark.parametrize(
("batch_spec", "num_speculative_tokens", "num_decode_draft_tokens_cpu"),
[
(
BatchSpec(
seq_lens=[8, 12],
query_lens=[4, 8],
name="pure_non_spec_prefill",
),
0,
None,
),
(
BatchSpec(
seq_lens=[8, 4, 0, 12],
query_lens=[4, 4, 0, 8],
name="mixed_spec_non_spec_with_padding",
),
3,
torch.tensor([-1, 3, -1, -1], dtype=torch.int32),
),
(
BatchSpec(
seq_lens=[5, 12, 0, 9],
query_lens=[1, 8, 0, 1],
name="mixed_prefill_decode_without_spec",
),
0,
None,
),
],
ids=lambda case: case.name if isinstance(case, BatchSpec) else None,
)
def test_builder_prebuilds_non_spec_chunk_metadata_exactly(
    batch_spec: BatchSpec,
    num_speculative_tokens: int,
    num_decode_draft_tokens_cpu: torch.Tensor | None,
):
    device = torch.device("cpu")
    num_heads = 32
    common_attn_metadata = create_common_attn_metadata(
        batch_spec=batch_spec,
        block_size=16,
        device=device,
    )
    builder = _make_builder(
        device=device,
        num_heads=num_heads,
        num_speculative_tokens=num_speculative_tokens,
    )
    num_accepted_tokens = None
    spec_mask_cpu = None
    if num_decode_draft_tokens_cpu is not None:
        num_accepted_tokens = torch.ones(
            batch_spec.batch_size, dtype=torch.int32, device=device
        )
        # Requests with draft-token counts >= 0 are speculative-decode requests.
        spec_mask_cpu = num_decode_draft_tokens_cpu >= 0

    attn_metadata = builder.build(
        0,
        common_attn_metadata,
        num_accepted_tokens=num_accepted_tokens,
        num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
    )

    # Recompute every piece of chunk metadata with the legacy helpers and
    # require exact equality with the prebuilt bundle.
    non_spec_query_start_loc_cpu = _build_non_spec_query_start_loc_cpu(
        batch_spec,
        spec_mask_cpu,
    )
    legacy_chunk_indices_64 = _prepare_chunk_indices(non_spec_query_start_loc_cpu, 64)
    legacy_chunk_offsets_64 = _prepare_chunk_offsets(non_spec_query_start_loc_cpu, 64)
    legacy_update_chunk_offsets_64 = _prepare_update_chunk_offsets(
        non_spec_query_start_loc_cpu,
        64,
    )
    legacy_final_chunk_indices_64 = _prepare_final_chunk_indices(
        non_spec_query_start_loc_cpu,
        64,
    )
    legacy_chunk_indices_large_block = _prepare_chunk_indices(
        non_spec_query_start_loc_cpu,
        patch_gdn_attn._GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE,
    )
    optim_block_size = _next_power_of_2(
        patch_gdn_attn._GDN_CUMSUM_WORKING_SET
        // (num_heads * patch_gdn_attn._GDN_CHUNK_SIZE)
    )
    legacy_block_indices_cumsum = _prepare_chunk_indices(
        non_spec_query_start_loc_cpu,
        optim_block_size,
    )

    prebuilt_meta = getattr(attn_metadata, "non_spec_chunked_prefill_meta", None)
    assert prebuilt_meta is not None
    assert torch.equal(prebuilt_meta.chunk_indices_chunk64, legacy_chunk_indices_64)
    assert torch.equal(prebuilt_meta.chunk_offsets_chunk64, legacy_chunk_offsets_64)
    assert torch.equal(
        prebuilt_meta.update_chunk_offsets_chunk64, legacy_update_chunk_offsets_64
    )
    assert torch.equal(
        prebuilt_meta.final_chunk_indices_chunk64, legacy_final_chunk_indices_64
    )
    assert torch.equal(
        prebuilt_meta.chunk_indices_large_block,
        legacy_chunk_indices_large_block,
    )
    assert torch.equal(
        prebuilt_meta.block_indices_cumsum,
        legacy_block_indices_cumsum,
    )


def test_allocate_chunked_prefill_slot_uses_cpugpubuffer(monkeypatch):
    class DummyCpuGpuBuffer:
        """Records constructor arguments so the test can assert on them."""

        def __init__(
            self,
            *size,
            dtype: torch.dtype,
            device: torch.device,
            pin_memory: bool,
            with_numpy: bool = True,
        ) -> None:
            self.cpu = torch.zeros(*size, dtype=dtype, device="cpu")
            self.gpu = torch.zeros_like(self.cpu, device=device)
            self.dtype = dtype
            self.device = device
            self.pin_memory = pin_memory
            self.with_numpy = with_numpy

    device = torch.device("cpu")
    builder = _make_builder(
        device=device,
        num_heads=32,
        num_speculative_tokens=0,
    )
    monkeypatch.setattr(patch_gdn_attn, "CpuGpuBuffer", DummyCpuGpuBuffer)
    slot = patch_gdn_attn._allocate_chunked_prefill_slot(builder, device)

    # Every staging buffer must be a pinned CpuGpuBuffer sized to
    # max_num_batched_tokens rows.
    assert isinstance(slot.chunk_indices_chunk64, DummyCpuGpuBuffer)
    assert isinstance(slot.chunk_offsets_chunk64, DummyCpuGpuBuffer)
    assert isinstance(slot.update_chunk_offsets_chunk64, DummyCpuGpuBuffer)
    assert isinstance(slot.final_chunk_indices_chunk64, DummyCpuGpuBuffer)
    assert slot.chunk_indices_chunk64.pin_memory is True
    assert slot.chunk_indices_chunk64.with_numpy is False
    assert slot.chunk_indices_chunk64.device == device
    assert slot.chunk_indices_chunk64.cpu.shape == (
        builder.vllm_config.scheduler_config.max_num_batched_tokens,
        2,
    )
    assert slot.chunk_indices_chunk64.gpu.shape == (
        builder.vllm_config.scheduler_config.max_num_batched_tokens,
        2,
    )


@pytest.mark.parametrize(
"batch_spec",
[
BatchSpec(seq_lens=[1, 1, 1], query_lens=[1, 1, 1], name="decode_only"),
BatchSpec(seq_lens=[4, 4], query_lens=[4, 4], name="spec_only"),
],
)
def test_builder_skips_prebuilt_meta_without_non_spec_prefill(batch_spec: BatchSpec):
device = torch.device("cpu")
builder = _make_builder(
device=device,
num_heads=32,
num_speculative_tokens=3 if batch_spec.name == "spec_only" else 0,
)
common_attn_metadata = create_common_attn_metadata(
batch_spec=batch_spec,
block_size=16,
device=device,
)
num_accepted_tokens = None
num_decode_draft_tokens_cpu = None
if batch_spec.name == "spec_only":
num_accepted_tokens = torch.ones(
batch_spec.batch_size, dtype=torch.int32, device=device
)
num_decode_draft_tokens_cpu = torch.full(
(batch_spec.batch_size,), 3, dtype=torch.int32
)
attn_metadata = builder.build(
0,
common_attn_metadata,
num_accepted_tokens=num_accepted_tokens,
num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
)
assert getattr(attn_metadata, "non_spec_chunked_prefill_meta", None) is None