xc-llm-ascend/vllm_ascend/patch/worker/patch_gdn_attn.py
Qi Mao 9bf9b4b267 [Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)
### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.

The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata inside the forward
pass. This keeps the CPU in the loop on every step and forces repeated
host/device round-trips. When running prefill-heavy workloads with
asynchronous scheduling enabled, these synchronizations show up as
execution "bubbles" and prefill stalls (stuttering). **Note that this
does not cause asynchronous scheduling to fail; rather, it keeps the
system below its theoretical throughput due to these unnecessary
stalls.**
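
For intuition, the offending pattern is any host-side read of device-resident data in the forward pass. A minimal, hypothetical sketch (not code from this repo; all names are illustrative):

```python
import torch

def build_chunk_indices_in_forward(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # cu_seqlens lives on the NPU, so .tolist() blocks the host until the
    # device stream drains -- the synchronization this patch avoids.
    seq_lens = torch.diff(cu_seqlens).tolist()
    rows = [
        (seq_idx, chunk_idx)
        for seq_idx, seq_len in enumerate(seq_lens)
        for chunk_idx in range((seq_len + chunk_size - 1) // chunk_size)
    ]
    # ...and the result is copied back to the device afterwards.
    return torch.tensor(rows, dtype=torch.int32, device=cu_seqlens.device)
```

Repeating this on every scheduler step is what produces the bubbles described above.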

To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU.
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously.
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path when no prebuilt metadata is provided, as sketched below.
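
To make the fallback concrete, the wrapper-side dispatch follows roughly this shape (a minimal sketch: `non_spec_chunked_prefill_meta` is the attribute this patch attaches, but the function name and signature here are illustrative, not the actual wrapper API):

```python
import torch

def _chunk_offsets_for_kernel(cu_seqlens: torch.Tensor,
                              attn_metadata,
                              chunk_size: int = 64) -> torch.Tensor:
    meta = getattr(attn_metadata, "non_spec_chunked_prefill_meta", None)
    if meta is not None:
        # Fast path: prebuilt on the CPU and already resident on the NPU.
        return meta.chunk_offsets_chunk64
    # Legacy path: recompute from cu_seqlens in the hot loop, accepting
    # whatever host/device synchronization that entails.
    counts = (torch.diff(cu_seqlens) + chunk_size - 1) // chunk_size
    offsets = torch.zeros(counts.numel() + 1, dtype=counts.dtype, device=counts.device)
    torch.cumsum(counts, dim=0, out=offsets[1:])
    return offsets
```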

- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
2026-03-22 23:09:23 +08:00

322 lines
13 KiB
Python

# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass

import torch
import vllm.v1.attention.backends.gdn_attn as gdn_attn
from vllm.v1.utils import CpuGpuBuffer

_GDN_CHUNK_SIZE = 64
# Keep this aligned with solve_tril.LARGE_BLOCK_T in ops/triton/fla/solve_tril.py.
_GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE = 608 * 2
_GDN_CUMSUM_WORKING_SET = 2**18

_IS_PATCHED = False
_ORIGINAL_BUILD = gdn_attn.GDNAttentionMetadataBuilder.build

@dataclass
class GDNChunkedPrefillMetadata:
    """Prebuilt varlen chunk metadata for the non-speculative GDN prefill path."""

    chunk_indices_chunk64: torch.Tensor
    chunk_offsets_chunk64: torch.Tensor
    update_chunk_offsets_chunk64: torch.Tensor
    final_chunk_indices_chunk64: torch.Tensor
    chunk_indices_large_block: torch.Tensor
    block_indices_cumsum: torch.Tensor
    # Keeps the staging buffers alive while the async device copies are in flight.
    _buffer_slot: object | None = None


@dataclass
class _GDNChunkedPrefillBufferSlot:
    """Pinned-memory CPU staging buffers paired with device-side tensors."""

    chunk_indices_chunk64: CpuGpuBuffer
    chunk_offsets_chunk64: CpuGpuBuffer
    update_chunk_offsets_chunk64: CpuGpuBuffer
    final_chunk_indices_chunk64: CpuGpuBuffer
    chunk_indices_large_block: CpuGpuBuffer
    block_indices_cumsum: CpuGpuBuffer

def _next_power_of_2(value: int) -> int:
    if value <= 1:
        return 1
    return 1 << (value - 1).bit_length()

def _prepare_chunk_counts_cpu(cu_seqlens_cpu: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # ceil(seq_len / chunk_size) per sequence, computed entirely on the CPU.
    lens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]
    return torch.div(lens + chunk_size - 1, chunk_size, rounding_mode="floor")

def _fill_chunk_indices_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    # Writes one (seq_idx, chunk_idx) row per chunk and returns the row count.
    cursor = 0
    for seq_idx, num_chunks in enumerate(chunk_counts.tolist()):
        if num_chunks <= 0:
            continue
        out[cursor : cursor + num_chunks, 0].fill_(seq_idx)
        out[cursor : cursor + num_chunks, 1] = torch.arange(
            num_chunks,
            dtype=out.dtype,
        )
        cursor += num_chunks
    return cursor

def _fill_chunk_offsets_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    # Chunk start offsets per sequence: [0, c0, c0 + c1, ...].
    out[0] = 0
    if chunk_counts.numel() > 0:
        torch.cumsum(chunk_counts, dim=0, out=out[1 : chunk_counts.numel() + 1])
    return chunk_counts.numel() + 1


def _fill_update_chunk_offsets_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    # Same as _fill_chunk_offsets_cpu, but each sequence contributes one
    # extra slot (chunk_counts + 1).
    out[0] = 0
    if chunk_counts.numel() > 0:
        torch.cumsum(
            chunk_counts + 1,
            dim=0,
            out=out[1 : chunk_counts.numel() + 1],
        )
    return chunk_counts.numel() + 1


def _fill_final_chunk_indices_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    # Index of the last chunk of each sequence in the (chunk_counts + 1) layout.
    if chunk_counts.numel() > 0:
        torch.cumsum(chunk_counts + 1, dim=0, out=out[: chunk_counts.numel()])
        out[: chunk_counts.numel()].sub_(1)
    return chunk_counts.numel()

def _get_gdn_num_heads(builder) -> int:
    # Per-rank GDN value-head count; falls back to the generic attention
    # head count when the HF config does not expose linear_num_value_heads.
    hf_text_config = getattr(builder.vllm_config.model_config, "hf_text_config", None)
    if hf_text_config is not None and hasattr(hf_text_config, "linear_num_value_heads"):
        return hf_text_config.linear_num_value_heads // builder.vllm_config.parallel_config.tensor_parallel_size
    return builder.vllm_config.model_config.get_num_attention_heads(builder.vllm_config.parallel_config)

def _allocate_chunked_prefill_slot(builder, device: torch.device):
    # One set of pinned CPU staging buffers with matching device tensors,
    # sized for the worst-case batch.
    max_num_batched_tokens = builder.vllm_config.scheduler_config.max_num_batched_tokens
    max_num_seqs = builder.vllm_config.scheduler_config.max_num_seqs
    return _GDNChunkedPrefillBufferSlot(
        chunk_indices_chunk64=CpuGpuBuffer(
            max_num_batched_tokens,
            2,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        chunk_offsets_chunk64=CpuGpuBuffer(
            max_num_seqs + 1,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        update_chunk_offsets_chunk64=CpuGpuBuffer(
            max_num_seqs + 1,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        final_chunk_indices_chunk64=CpuGpuBuffer(
            max_num_seqs,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        chunk_indices_large_block=CpuGpuBuffer(
            max_num_batched_tokens,
            2,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        block_indices_cumsum=CpuGpuBuffer(
            max_num_batched_tokens,
            2,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
    )

def _ensure_chunk_meta_state(builder, device: torch.device) -> None:
    # Lazy, one-time initialization of per-builder state, including the
    # double-buffered staging pool (device builds only).
    if getattr(builder, "_ascend_gdn_chunk_meta_initialized", False):
        return
    builder._ascend_gdn_chunk_meta_initialized = True
    builder._ascend_gdn_chunk_meta_device = device
    builder._ascend_gdn_chunk_size = _GDN_CHUNK_SIZE
    builder._ascend_gdn_large_block_size = _GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE
    gdn_num_heads = _get_gdn_num_heads(builder)
    cumsum_chunks = max(1, _GDN_CUMSUM_WORKING_SET // (gdn_num_heads * builder._ascend_gdn_chunk_size))
    builder._ascend_gdn_cumsum_block_size = _next_power_of_2(cumsum_chunks)
    builder._ascend_gdn_chunked_prefill_pool_idx = -1
    builder._ascend_gdn_chunked_prefill_pool = []
    if device.type != "cpu":
        builder._ascend_gdn_chunked_prefill_pool = [
            _allocate_chunked_prefill_slot(builder, device),
            _allocate_chunked_prefill_slot(builder, device),
        ]

def _build_non_spec_query_start_loc_cpu(
    builder,
    attn_metadata,
    common_attn_metadata,
    num_decode_draft_tokens_cpu: torch.Tensor | None,
) -> torch.Tensor | None:
    # Cumulative query lengths restricted to the non-speculative sequences,
    # or None when the batch contains no prefills.
    if attn_metadata.num_prefills <= 0:
        return None
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    if (
        not getattr(builder, "use_spec_decode", False)
        or num_decode_draft_tokens_cpu is None
        or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0].sum().item() == 0
    ):
        return query_start_loc_cpu
    spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
    if spec_sequence_masks_cpu.sum().item() == 0:
        return query_start_loc_cpu
    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    non_spec_query_lens_cpu = query_lens_cpu[~spec_sequence_masks_cpu]
    non_spec_query_start_loc_cpu = torch.zeros(
        non_spec_query_lens_cpu.numel() + 1,
        dtype=query_start_loc_cpu.dtype,
    )
    torch.cumsum(
        non_spec_query_lens_cpu,
        dim=0,
        out=non_spec_query_start_loc_cpu[1:],
    )
    return non_spec_query_start_loc_cpu

def _build_non_spec_chunked_prefill_meta_cpu(builder, cu_seqlens_cpu: torch.Tensor) -> GDNChunkedPrefillMetadata:
    # CPU-device build: allocate fresh tensors instead of using the pool.
    chunk_counts_chunk64 = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_chunk_size)
    chunk_counts_large = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_large_block_size)
    chunk_counts_cumsum = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_cumsum_block_size)
    num_seqs = chunk_counts_chunk64.numel()
    chunk_indices_chunk64 = torch.empty((int(chunk_counts_chunk64.sum().item()), 2), dtype=torch.int32)
    chunk_offsets_chunk64 = torch.empty((num_seqs + 1,), dtype=torch.int32)
    update_chunk_offsets_chunk64 = torch.empty((num_seqs + 1,), dtype=torch.int32)
    final_chunk_indices_chunk64 = torch.empty((num_seqs,), dtype=torch.int32)
    chunk_indices_large_block = torch.empty((int(chunk_counts_large.sum().item()), 2), dtype=torch.int32)
    block_indices_cumsum = torch.empty((int(chunk_counts_cumsum.sum().item()), 2), dtype=torch.int32)
    _fill_chunk_indices_cpu(chunk_indices_chunk64, chunk_counts_chunk64)
    _fill_chunk_offsets_cpu(chunk_offsets_chunk64, chunk_counts_chunk64)
    _fill_update_chunk_offsets_cpu(update_chunk_offsets_chunk64, chunk_counts_chunk64)
    _fill_final_chunk_indices_cpu(final_chunk_indices_chunk64, chunk_counts_chunk64)
    _fill_chunk_indices_cpu(chunk_indices_large_block, chunk_counts_large)
    _fill_chunk_indices_cpu(block_indices_cumsum, chunk_counts_cumsum)
    return GDNChunkedPrefillMetadata(
        chunk_indices_chunk64=chunk_indices_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        chunk_offsets_chunk64=chunk_offsets_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        update_chunk_offsets_chunk64=update_chunk_offsets_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        final_chunk_indices_chunk64=final_chunk_indices_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        chunk_indices_large_block=chunk_indices_large_block.to(builder._ascend_gdn_chunk_meta_device),
        block_indices_cumsum=block_indices_cumsum.to(builder._ascend_gdn_chunk_meta_device),
    )

def _build_non_spec_chunked_prefill_meta(builder, cu_seqlens_cpu: torch.Tensor) -> GDNChunkedPrefillMetadata:
    device = builder._ascend_gdn_chunk_meta_device
    if device.type == "cpu":
        return _build_non_spec_chunked_prefill_meta_cpu(builder, cu_seqlens_cpu)
    # Rotate between the two staging slots so the build for the next step
    # cannot overwrite buffers whose async copies are still in flight.
    builder._ascend_gdn_chunked_prefill_pool_idx = (builder._ascend_gdn_chunked_prefill_pool_idx + 1) % len(
        builder._ascend_gdn_chunked_prefill_pool
    )
    slot = builder._ascend_gdn_chunked_prefill_pool[builder._ascend_gdn_chunked_prefill_pool_idx]
    chunk_counts_chunk64 = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_chunk_size)
    chunk_counts_large = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_large_block_size)
    chunk_counts_cumsum = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_cumsum_block_size)
    num_chunk_indices_chunk64 = _fill_chunk_indices_cpu(slot.chunk_indices_chunk64.cpu, chunk_counts_chunk64)
    num_chunk_offsets_chunk64 = _fill_chunk_offsets_cpu(slot.chunk_offsets_chunk64.cpu, chunk_counts_chunk64)
    num_update_chunk_offsets_chunk64 = _fill_update_chunk_offsets_cpu(
        slot.update_chunk_offsets_chunk64.cpu, chunk_counts_chunk64
    )
    num_final_chunk_indices_chunk64 = _fill_final_chunk_indices_cpu(
        slot.final_chunk_indices_chunk64.cpu, chunk_counts_chunk64
    )
    num_chunk_indices_large = _fill_chunk_indices_cpu(slot.chunk_indices_large_block.cpu, chunk_counts_large)
    num_block_indices_cumsum = _fill_chunk_indices_cpu(slot.block_indices_cumsum.cpu, chunk_counts_cumsum)
    # copy_to_gpu issues non-blocking copies from the pinned staging buffers.
    chunk_indices_chunk64 = slot.chunk_indices_chunk64.copy_to_gpu(num_chunk_indices_chunk64)
    chunk_offsets_chunk64 = slot.chunk_offsets_chunk64.copy_to_gpu(num_chunk_offsets_chunk64)
    update_chunk_offsets_chunk64 = slot.update_chunk_offsets_chunk64.copy_to_gpu(num_update_chunk_offsets_chunk64)
    final_chunk_indices_chunk64 = slot.final_chunk_indices_chunk64.copy_to_gpu(num_final_chunk_indices_chunk64)
    chunk_indices_large_block = slot.chunk_indices_large_block.copy_to_gpu(num_chunk_indices_large)
    block_indices_cumsum = slot.block_indices_cumsum.copy_to_gpu(num_block_indices_cumsum)
    return GDNChunkedPrefillMetadata(
        chunk_indices_chunk64=chunk_indices_chunk64,
        chunk_offsets_chunk64=chunk_offsets_chunk64,
        update_chunk_offsets_chunk64=update_chunk_offsets_chunk64,
        final_chunk_indices_chunk64=final_chunk_indices_chunk64,
        chunk_indices_large_block=chunk_indices_large_block,
        block_indices_cumsum=block_indices_cumsum,
        _buffer_slot=slot,
    )

def _patched_build(
    self,
    common_prefix_len: int,
    common_attn_metadata,
    num_accepted_tokens: torch.Tensor | None = None,
    num_decode_draft_tokens_cpu: torch.Tensor | None = None,
    fast_build: bool = False,
):
    # Wraps the upstream builder and attaches the prebuilt chunk metadata
    # whenever the batch contains prefill sequences.
    attn_metadata = _ORIGINAL_BUILD(
        self,
        common_prefix_len,
        common_attn_metadata,
        num_accepted_tokens=num_accepted_tokens,
        num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
        fast_build=fast_build,
    )
    attn_metadata.non_spec_chunked_prefill_meta = None
    if attn_metadata.num_prefills <= 0:
        return attn_metadata
    _ensure_chunk_meta_state(self, common_attn_metadata.query_start_loc.device)
    non_spec_query_start_loc_cpu = _build_non_spec_query_start_loc_cpu(
        self,
        attn_metadata,
        common_attn_metadata,
        num_decode_draft_tokens_cpu,
    )
    assert non_spec_query_start_loc_cpu is not None
    attn_metadata.non_spec_chunked_prefill_meta = _build_non_spec_chunked_prefill_meta(
        self, non_spec_query_start_loc_cpu
    )
    return attn_metadata

# Install the patch once at import time.
if not _IS_PATCHED:
    gdn_attn.GDNChunkedPrefillMetadata = GDNChunkedPrefillMetadata
    gdn_attn.GDNAttentionMetadataBuilder.build = _patched_build
    _IS_PATCHED = True