[Feature] Optimize Qwen3.5/Qwen3Next GDN prefill by prebuilding chunk metadata (#7487)
### What this PR does / why we need it?
This PR optimizes the Qwen3.5 and Qwen3Next GDN prefill path on Ascend
by reducing host/device synchronization overhead.
The current implementation of the `chunk_gated_delta_rule` path for
variable-length sequences prepares chunk metadata during the forward
pass. This forces frequent CPU involvement and host/device round-trips.
In prefill-heavy workloads with asynchronous scheduling enabled, these
synchronizations show up as execution "bubbles" and prefill stalls
(stuttering). **Note that this does not cause asynchronous scheduling to
fail; rather, it prevents the system from reaching its theoretical
throughput due to these unnecessary stalls.**
To resolve this, the patch moves metadata preparation out of the hot
path:
- **Prebuilt Metadata:** All non-speculative varlen chunk metadata for
GDN is now prebuilt on the CPU.
- **Asynchronous Transfer:** Staging buffers are kept in pinned memory
and transferred to the NPU asynchronously.
- **Integration:** The prebuilt bundle is attached to GDN attention
metadata via `patch_gdn_attn.py` and passed into Triton wrappers.
- **Backward Compatibility:** Triton wrappers fall back to the legacy
preparation path if no prebuilt metadata is provided (see the sketch below).
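In outline, the Triton wrappers now accept the prebuilt bundle and only fall back to on-the-fly preparation when it is absent. Below is a minimal sketch of the prebuild/stage/fallback pattern; `ChunkMetaStager`, `prepare_chunk_indices_cpu`, and `kernel_wrapper` are illustrative names for this sketch, not the PR's actual APIs (those live in `patch_gdn_attn.py` and the Triton wrapper modules), and the pinned-memory allocation assumes an accelerator-enabled PyTorch build:

```python
import torch


def prepare_chunk_indices_cpu(cu_seqlens_cpu: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Emit one (seq_idx, chunk_idx) pair per chunk, entirely on the host."""
    lens = (cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]).tolist()
    pairs = [
        (seq_idx, chunk_idx)
        for seq_idx, seq_len in enumerate(lens)
        for chunk_idx in range((seq_len + chunk_size - 1) // chunk_size)
    ]
    return torch.tensor(pairs, dtype=torch.int32).reshape(-1, 2)


class ChunkMetaStager:
    """Hypothetical staging helper: pinned CPU buffer plus async device copy."""

    def __init__(self, max_chunks: int, device: torch.device):
        # Pinned host memory allows a non-blocking host-to-device copy.
        self.cpu_buf = torch.empty(max_chunks, 2, dtype=torch.int32, pin_memory=True)
        self.dev_buf = torch.empty(max_chunks, 2, dtype=torch.int32, device=device)

    def stage(self, cu_seqlens_cpu: torch.Tensor, chunk_size: int) -> torch.Tensor:
        indices = prepare_chunk_indices_cpu(cu_seqlens_cpu, chunk_size)
        n = indices.shape[0]
        self.cpu_buf[:n].copy_(indices)
        # The async copy overlaps with compute instead of forcing a sync point.
        self.dev_buf[:n].copy_(self.cpu_buf[:n], non_blocking=True)
        return self.dev_buf[:n]


def kernel_wrapper(cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor | None = None):
    # Legacy fallback: recompute metadata on the fly when nothing was prebuilt.
    if chunk_indices is None:
        chunk_indices = prepare_chunk_indices_cpu(cu_seqlens.cpu(), 64).to(cu_seqlens.device)
    ...  # launch the Triton kernel with chunk_indices
```

The key property is that the prebuilt and legacy paths produce bit-identical metadata, so the wrappers can consume either source interchangeably; the new unit tests below assert exactly this.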
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
---------
Signed-off-by: maoxx241 <maomaoyu870@gmail.com>
tests/ut/patch/worker/patch_common/test_patch_gdn_attn.py (Normal file, 389 lines)
@@ -0,0 +1,389 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from types import SimpleNamespace

import pytest
import torch

import vllm_ascend.patch.worker.patch_gdn_attn as patch_gdn_attn
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.kv_cache_interface import MambaSpec


@dataclass
class BatchSpec:
    seq_lens: list[int]
    query_lens: list[int]
    name: str = "unnamed"

    @property
    def batch_size(self):
        return len(self.seq_lens)


def create_common_attn_metadata(
    batch_spec: BatchSpec,
    block_size: int,
    device: torch.device,
) -> CommonAttentionMetadata:
    query_start_loc = torch.zeros(
        batch_spec.batch_size + 1,
        dtype=torch.int32,
        device=device,
    )
    query_start_loc[1:] = torch.tensor(
        batch_spec.query_lens,
        dtype=torch.int32,
        device=device,
    ).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()
    num_tokens = sum(batch_spec.query_lens)

    seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())
    context_lens = [
        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
        for i in range(batch_spec.batch_size)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
    max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
    block_table_tensor = torch.arange(
        batch_spec.batch_size * max_blocks,
        dtype=torch.int32,
        device=device,
    ).view(batch_spec.batch_size, max_blocks)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=batch_spec.batch_size,
        num_actual_tokens=num_tokens,
        max_query_len=max(batch_spec.query_lens),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )


def _make_vllm_config(
    *,
    max_model_len: int = 8192,
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    num_heads: int = 32,
    num_speculative_tokens: int = 0,
):
    speculative_config = None
    if num_speculative_tokens > 0:
        speculative_config = SimpleNamespace(
            num_speculative_tokens=num_speculative_tokens,
            parallel_drafting=False,
        )

    model_config = SimpleNamespace(max_model_len=max_model_len)
    model_config.get_num_attention_heads = lambda parallel_config: num_heads

    return SimpleNamespace(
        cache_config=SimpleNamespace(mamba_cache_mode="none"),
        compilation_config=SimpleNamespace(
            cudagraph_mode=CUDAGraphMode.NONE,
            max_cudagraph_capture_size=None,
        ),
        speculative_config=speculative_config,
        scheduler_config=SimpleNamespace(
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
        ),
        parallel_config=SimpleNamespace(
            decode_context_parallel_size=1,
            tensor_parallel_size=1,
        ),
        model_config=model_config,
    )


def _make_builder(*, device: torch.device, num_heads: int, num_speculative_tokens: int):
    vllm_config = _make_vllm_config(
        num_heads=num_heads,
        num_speculative_tokens=num_speculative_tokens,
    )
    spec = MambaSpec(
        block_size=16,
        shapes=((1,), (1,)),
        dtypes=(torch.float32,),
        mamba_cache_mode="none",
    )
    return GDNAttentionMetadataBuilder(spec, ["layer0"], vllm_config, device)


def _next_power_of_2(value: int) -> int:
    if value <= 1:
        return 1
    return 1 << (value - 1).bit_length()


def _prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    pairs: list[list[int]] = []
    for seq_idx, seq_len in enumerate(lens):
        num_chunks = (seq_len + chunk_size - 1) // chunk_size
        for chunk_idx in range(num_chunks):
            pairs.append([seq_idx, chunk_idx])
    if not pairs:
        return torch.empty((0, 2), dtype=cu_seqlens.dtype, device=cu_seqlens.device)
    return torch.tensor(pairs, dtype=cu_seqlens.dtype, device=cu_seqlens.device)


def _prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    num_chunks = torch.div(
        lens + chunk_size - 1,
        chunk_size,
        rounding_mode="floor",
    )
    offsets = torch.zeros(len(num_chunks) + 1, dtype=cu_seqlens.dtype)
    torch.cumsum(num_chunks, dim=0, out=offsets[1:])
    return offsets.to(cu_seqlens.device)


def _prepare_update_chunk_offsets(
    cu_seqlens: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    num_chunks = torch.div(
        lens + chunk_size - 1,
        chunk_size,
        rounding_mode="floor",
    ) + 1
    offsets = torch.zeros(len(num_chunks) + 1, dtype=cu_seqlens.dtype)
    torch.cumsum(num_chunks, dim=0, out=offsets[1:])
    return offsets.to(cu_seqlens.device)


def _prepare_final_chunk_indices(
    cu_seqlens: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    num_chunks = torch.div(
        lens + chunk_size - 1,
        chunk_size,
        rounding_mode="floor",
    ) + 1
    return (torch.cumsum(num_chunks, dim=0) - 1).to(cu_seqlens.device)


def _build_non_spec_query_start_loc_cpu(
    batch_spec: BatchSpec, spec_mask_cpu: torch.Tensor | None
) -> torch.Tensor:
    query_lens = torch.tensor(batch_spec.query_lens, dtype=torch.int32)
    if spec_mask_cpu is not None:
        query_lens = query_lens[~spec_mask_cpu]
    query_start_loc = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
    torch.cumsum(query_lens, dim=0, out=query_start_loc[1:])
    return query_start_loc


@pytest.mark.parametrize(
    ("batch_spec", "num_speculative_tokens", "num_decode_draft_tokens_cpu"),
    [
        (
            BatchSpec(
                seq_lens=[8, 12],
                query_lens=[4, 8],
                name="pure_non_spec_prefill",
            ),
            0,
            None,
        ),
        (
            BatchSpec(
                seq_lens=[8, 4, 0, 12],
                query_lens=[4, 4, 0, 8],
                name="mixed_spec_non_spec_with_padding",
            ),
            3,
            torch.tensor([-1, 3, -1, -1], dtype=torch.int32),
        ),
        (
            BatchSpec(
                seq_lens=[5, 12, 0, 9],
                query_lens=[1, 8, 0, 1],
                name="mixed_prefill_decode_without_spec",
            ),
            0,
            None,
        ),
    ],
    ids=lambda case: case.name if isinstance(case, BatchSpec) else None,
)
def test_builder_prebuilds_non_spec_chunk_metadata_exactly(
    batch_spec: BatchSpec,
    num_speculative_tokens: int,
    num_decode_draft_tokens_cpu: torch.Tensor | None,
):
    device = torch.device("cpu")
    num_heads = 32
    common_attn_metadata = create_common_attn_metadata(
        batch_spec=batch_spec,
        block_size=16,
        device=device,
    )
    builder = _make_builder(
        device=device,
        num_heads=num_heads,
        num_speculative_tokens=num_speculative_tokens,
    )

    num_accepted_tokens = None
    spec_mask_cpu = None
    if num_decode_draft_tokens_cpu is not None:
        num_accepted_tokens = torch.ones(
            batch_spec.batch_size, dtype=torch.int32, device=device
        )
        spec_mask_cpu = num_decode_draft_tokens_cpu >= 0

    attn_metadata = builder.build(
        0,
        common_attn_metadata,
        num_accepted_tokens=num_accepted_tokens,
        num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
    )

    non_spec_query_start_loc_cpu = _build_non_spec_query_start_loc_cpu(
        batch_spec,
        spec_mask_cpu,
    )
    legacy_chunk_indices_64 = _prepare_chunk_indices(non_spec_query_start_loc_cpu, 64)
    legacy_chunk_offsets_64 = _prepare_chunk_offsets(non_spec_query_start_loc_cpu, 64)
    legacy_update_chunk_offsets_64 = _prepare_update_chunk_offsets(
        non_spec_query_start_loc_cpu,
        64,
    )
    legacy_final_chunk_indices_64 = _prepare_final_chunk_indices(
        non_spec_query_start_loc_cpu,
        64,
    )
    legacy_chunk_indices_large_block = _prepare_chunk_indices(
        non_spec_query_start_loc_cpu,
        patch_gdn_attn._GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE,
    )
    optim_block_size = _next_power_of_2(
        patch_gdn_attn._GDN_CUMSUM_WORKING_SET
        // (num_heads * patch_gdn_attn._GDN_CHUNK_SIZE)
    )
    legacy_block_indices_cumsum = _prepare_chunk_indices(
        non_spec_query_start_loc_cpu,
        optim_block_size,
    )

    prebuilt_meta = getattr(attn_metadata, "non_spec_chunked_prefill_meta", None)
    assert prebuilt_meta is not None
    assert torch.equal(prebuilt_meta.chunk_indices_chunk64, legacy_chunk_indices_64)
    assert torch.equal(prebuilt_meta.chunk_offsets_chunk64, legacy_chunk_offsets_64)
    assert torch.equal(
        prebuilt_meta.update_chunk_offsets_chunk64, legacy_update_chunk_offsets_64
    )
    assert torch.equal(
        prebuilt_meta.final_chunk_indices_chunk64, legacy_final_chunk_indices_64
    )
    assert torch.equal(
        prebuilt_meta.chunk_indices_large_block,
        legacy_chunk_indices_large_block,
    )
    assert torch.equal(
        prebuilt_meta.block_indices_cumsum,
        legacy_block_indices_cumsum,
    )


def test_allocate_chunked_prefill_slot_uses_cpugpubuffer(monkeypatch):
    class DummyCpuGpuBuffer:
        def __init__(
            self,
            *size,
            dtype: torch.dtype,
            device: torch.device,
            pin_memory: bool,
            with_numpy: bool = True,
        ) -> None:
            self.cpu = torch.zeros(*size, dtype=dtype, device="cpu")
            self.gpu = torch.zeros_like(self.cpu, device=device)
            self.dtype = dtype
            self.device = device
            self.pin_memory = pin_memory
            self.with_numpy = with_numpy

    device = torch.device("cpu")
    builder = _make_builder(
        device=device,
        num_heads=32,
        num_speculative_tokens=0,
    )
    monkeypatch.setattr(patch_gdn_attn, "CpuGpuBuffer", DummyCpuGpuBuffer)

    slot = patch_gdn_attn._allocate_chunked_prefill_slot(builder, device)

    assert isinstance(slot.chunk_indices_chunk64, DummyCpuGpuBuffer)
    assert isinstance(slot.chunk_offsets_chunk64, DummyCpuGpuBuffer)
    assert isinstance(slot.update_chunk_offsets_chunk64, DummyCpuGpuBuffer)
    assert isinstance(slot.final_chunk_indices_chunk64, DummyCpuGpuBuffer)
    assert slot.chunk_indices_chunk64.pin_memory is True
    assert slot.chunk_indices_chunk64.with_numpy is False
    assert slot.chunk_indices_chunk64.device == device
    assert slot.chunk_indices_chunk64.cpu.shape == (
        builder.vllm_config.scheduler_config.max_num_batched_tokens,
        2,
    )
    assert slot.chunk_indices_chunk64.gpu.shape == (
        builder.vllm_config.scheduler_config.max_num_batched_tokens,
        2,
    )


@pytest.mark.parametrize(
    "batch_spec",
    [
        BatchSpec(seq_lens=[1, 1, 1], query_lens=[1, 1, 1], name="decode_only"),
        BatchSpec(seq_lens=[4, 4], query_lens=[4, 4], name="spec_only"),
    ],
)
def test_builder_skips_prebuilt_meta_without_non_spec_prefill(batch_spec: BatchSpec):
    device = torch.device("cpu")
    builder = _make_builder(
        device=device,
        num_heads=32,
        num_speculative_tokens=3 if batch_spec.name == "spec_only" else 0,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec=batch_spec,
        block_size=16,
        device=device,
    )

    num_accepted_tokens = None
    num_decode_draft_tokens_cpu = None
    if batch_spec.name == "spec_only":
        num_accepted_tokens = torch.ones(
            batch_spec.batch_size, dtype=torch.int32, device=device
        )
        num_decode_draft_tokens_cpu = torch.full(
            (batch_spec.batch_size,), 3, dtype=torch.int32
        )

    attn_metadata = builder.build(
        0,
        common_attn_metadata,
        num_accepted_tokens=num_accepted_tokens,
        num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
    )

    assert getattr(attn_metadata, "non_spec_chunked_prefill_meta", None) is None
@@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.LongTensor | None = None,
+    prebuilt_meta=None,
 ):
     forward_context = get_forward_context()
     num_decodes = 0
@@ -47,10 +48,34 @@ def chunk_gated_delta_rule_fwd(
     if attn_metadata is not None:
         num_decodes = attn_metadata.num_decodes
     chunk_size = 64
-    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
+    block_indices_cumsum = None if prebuilt_meta is None else prebuilt_meta.block_indices_cumsum
+    chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_chunk64
+    chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_offsets_chunk64
+    update_chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.update_chunk_offsets_chunk64
+    final_chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.final_chunk_indices_chunk64
+    chunk_indices_large_block = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_large_block
+    g = chunk_local_cumsum(
+        g,
+        chunk_size=chunk_size,
+        cu_seqlens=cu_seqlens,
+        block_indices=block_indices_cumsum,
+    )
     # obtain WY representation. u is actually the new v.
-    A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
-    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
+    A = chunk_scaled_dot_kkt_fwd(
+        k=k,
+        beta=beta,
+        g_cumsum=g,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
+        output_dtype=torch.float32,
+    )
+    A = solve_tril(
+        A=A,
+        cu_seqlens=cu_seqlens,
+        chunk_indices_large_block=chunk_indices_large_block,
+        chunk_indices_bt=chunk_indices_chunk64,
+        output_dtype=k.dtype,
+    )
     w, u = recompute_w_u_fwd(
         k=k,
         v=v,
@@ -58,6 +83,7 @@ def chunk_gated_delta_rule_fwd(
         A=A,
         g_cumsum=g,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
     )
     h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
         k=k,
@@ -67,6 +93,8 @@ def chunk_gated_delta_rule_fwd(
         initial_state=initial_state,
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices_chunk64,
+        chunk_offsets=chunk_offsets_chunk64,
     )

     if get_pcp_group().world_size > 1:
@@ -76,10 +104,15 @@ def chunk_gated_delta_rule_fwd(
             u=u,
             g=g,
             cu_seqlens=cu_seqlens,
+            chunk_indices=chunk_indices_chunk64,
+            chunk_offsets=chunk_offsets_chunk64,
+            update_chunk_offsets=update_chunk_offsets_chunk64,
             num_decodes=num_decodes,
         )
         all_final_state = get_pcp_group().all_gather(final_state.unsqueeze(0), 0)
-        final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
+        final_chunk_indices = final_chunk_indices_chunk64
+        if final_chunk_indices is None:
+            final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
         final_h_update = h_update[:, final_chunk_indices, :, :, :]
         all_final_h_update = get_pcp_group().all_gather(final_h_update, 0)

@@ -137,6 +170,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
         initial_state: torch.Tensor,
         output_final_state: bool,
         cu_seqlens: torch.LongTensor | None = None,
+        prebuilt_meta=None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
         if use_qk_l2norm_in_kernel:
@@ -152,6 +186,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
+            prebuilt_meta=prebuilt_meta,
         )
         ctx.scale = scale
         ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
@@ -169,6 +204,7 @@ def chunk_gated_delta_rule(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: torch.LongTensor | None = None,
+    prebuilt_meta=None,
     head_first: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
 ):
@@ -268,7 +304,17 @@ def chunk_gated_delta_rule(
     if scale is None:
         scale = k.shape[-1] ** -0.5
     o, final_state = ChunkGatedDeltaRuleFunction.apply(
-        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        prebuilt_meta,
+        use_qk_l2norm_in_kernel,
     )
     if head_first:
         o = rearrange(o, "b t h ... -> b h t ...")
@@ -186,6 +186,8 @@ def chunk_gated_delta_rule_fwd_h(
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     save_new_value: bool = True,
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
     # In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -193,15 +195,18 @@ def chunk_gated_delta_rule_fwd_h(
     H = u.shape[-2]
     BT = chunk_size

-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
+        if chunk_offsets is None:
+            chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
         N, NT, chunk_offsets = (
             len(cu_seqlens) - 1,
             len(chunk_indices),
-            prepare_chunk_offsets(cu_seqlens, BT),
+            chunk_offsets,
         )
     assert K <= 256, "current kernel does not support head dimension larger than 256."

@@ -163,6 +163,9 @@ def chunk_gated_delta_rule_fwd_hupdate(
     g: torch.Tensor | None = None,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
+    update_chunk_offsets: torch.Tensor | None = None,
     num_decodes: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
@@ -171,20 +174,25 @@ def chunk_gated_delta_rule_fwd_hupdate(
     H = u.shape[-2]
     BT = chunk_size

-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
+        if chunk_offsets is None:
+            chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
         N, NT, chunk_offsets = (
             len(cu_seqlens) - 1,
             len(chunk_indices),
-            prepare_chunk_offsets(cu_seqlens, BT),
+            chunk_offsets,
         )
     assert K <= 256, "current kernel does not support head dimension larger than 256."

     h_update = k.new_empty(B, NT + N, H, K, K, dtype=torch.float32)
-    update_indices = prepare_update_chunk_offsets(cu_seqlens, BT)[:-1]
+    if cu_seqlens is not None and update_chunk_offsets is None:
+        update_chunk_offsets = prepare_update_chunk_offsets(cu_seqlens, BT)
+    update_indices = update_chunk_offsets[:-1]
     h_update[:, update_indices, :, :, :] = torch.eye(K, dtype=h_update.dtype, device=h_update.device)

     g = g.transpose(1, 2).contiguous()
@@ -85,6 +85,7 @@ def chunk_scaled_dot_kkt_fwd(
     beta: torch.Tensor,
     g_cumsum: torch.Tensor | None = None,
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -115,13 +116,8 @@ def chunk_scaled_dot_kkt_fwd(

     H = beta.shape[-1]
     BT = chunk_size
-    if cu_seqlens is not None:
-        cu_seqlens = cu_seqlens.cpu()
-        chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
-        chunk_indices = chunk_indices.npu()
-        cu_seqlens = cu_seqlens.npu()
-    else:
-        chunk_indices = None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)

@@ -81,6 +81,7 @@ def chunk_local_cumsum_scalar(
     reverse: bool = False,
     scale: float = None,
     cu_seqlens: torch.Tensor | None = None,
+    block_indices: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.Tensor | None = torch.float,
 ):
@@ -90,7 +91,8 @@ def chunk_local_cumsum_scalar(
     B, T, H = g.shape
     assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
     OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
-    block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
+    if cu_seqlens is not None and block_indices is None:
+        block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE)
     num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE)
     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
     grid = (num_blocks, B)
@@ -132,6 +134,7 @@ def chunk_local_cumsum(
         reverse=reverse,
         scale=scale,
         cu_seqlens=cu_seqlens,
+        block_indices=kwargs.get("block_indices"),
         head_first=head_first,
         output_dtype=output_dtype,
     )
@@ -330,6 +330,8 @@ def merge_16x16_to_64x64_inverse_kernel(
 def solve_tril(
     A: torch.Tensor,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices_large_block: torch.Tensor | None = None,
+    chunk_indices_bt: torch.Tensor | None = None,
     output_dtype: torch.dtype = torch.float,
 ) -> torch.Tensor:
     """
@@ -355,7 +357,9 @@ def solve_tril(

     LARGE_BLOCK_T = 608 * 2

-    chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices_large_block is None:
+        chunk_indices_large_block = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
+    chunk_indices = chunk_indices_large_block
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)

     solve_tril_16x16_kernel[NT, B * H](
@@ -376,7 +380,9 @@ def solve_tril(

     Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
     merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
-    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices_bt is None:
+        chunk_indices_bt = prepare_chunk_indices(cu_seqlens, BT)
+    chunk_indices = chunk_indices_bt
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

     merge_fn[NT, B * H](
@@ -102,12 +102,14 @@ def recompute_w_u_fwd(
     g_cumsum: torch.Tensor,
     A: torch.Tensor,
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, Hg, K, V = *k.shape, v.shape[-1]
     H = v.shape[-2]
     BT = A.shape[-1]

-    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    if cu_seqlens is not None and chunk_indices is None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

     BK = 64
@@ -259,6 +259,20 @@
 #       Remove this patch when bool is supported in 'torch.argsort' func of npu.
 #       Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
 #
+# ** 7a. File: worker/patch_gdn_attn.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.attention.backends.gdn_attn.GDNAttentionMetadataBuilder.build`
+#    Why:
+#       Qwen3.5/Qwen3Next GDN prefill on NPU needs prebuilt varlen chunk metadata
+#       to avoid forward-time host round-trips that break async scheduling.
+#    How:
+#       Monkey-patch the upstream builder in-place, keep upstream code untouched,
+#       and attach prebuilt device metadata bundle onto the returned attention
+#       metadata object for Ascend-specific consumers.
+#    Future Plan:
+#       Remove this patch when upstream exposes a backend hook for extending GDN
+#       metadata or when the optimization is accepted upstream directly.
+#
 # ** 8. File: worker/patch_qwen3_next.py**
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #   1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet.forward`
@@ -31,6 +31,7 @@ import vllm_ascend.patch.worker.patch_minimax_m2  # noqa
 import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn  # noqa
 import vllm_ascend.patch.worker.patch_mamba_utils  # noqa
 import vllm_ascend.patch.worker.patch_multimodal_merge  # noqa
+import vllm_ascend.patch.worker.patch_gdn_attn  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_5  # noqa
vllm_ascend/patch/worker/patch_gdn_attn.py (Normal file, 321 lines)
@@ -0,0 +1,321 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch
import vllm.v1.attention.backends.gdn_attn as gdn_attn
from vllm.v1.utils import CpuGpuBuffer

_GDN_CHUNK_SIZE = 64
# Keep this aligned with solve_tril.LARGE_BLOCK_T in ops/triton/fla/solve_tril.py.
_GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE = 608 * 2
_GDN_CUMSUM_WORKING_SET = 2**18

_IS_PATCHED = False
_ORIGINAL_BUILD = gdn_attn.GDNAttentionMetadataBuilder.build


@dataclass
class GDNChunkedPrefillMetadata:
    chunk_indices_chunk64: torch.Tensor
    chunk_offsets_chunk64: torch.Tensor
    update_chunk_offsets_chunk64: torch.Tensor
    final_chunk_indices_chunk64: torch.Tensor
    chunk_indices_large_block: torch.Tensor
    block_indices_cumsum: torch.Tensor
    _buffer_slot: object | None = None


@dataclass
class _GDNChunkedPrefillBufferSlot:
    chunk_indices_chunk64: CpuGpuBuffer
    chunk_offsets_chunk64: CpuGpuBuffer
    update_chunk_offsets_chunk64: CpuGpuBuffer
    final_chunk_indices_chunk64: CpuGpuBuffer
    chunk_indices_large_block: CpuGpuBuffer
    block_indices_cumsum: CpuGpuBuffer


def _next_power_of_2(value: int) -> int:
    if value <= 1:
        return 1
    return 1 << (value - 1).bit_length()


def _prepare_chunk_counts_cpu(cu_seqlens_cpu: torch.Tensor, chunk_size: int) -> torch.Tensor:
    lens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]
    return torch.div(lens + chunk_size - 1, chunk_size, rounding_mode="floor")


def _fill_chunk_indices_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    cursor = 0
    for seq_idx, num_chunks in enumerate(chunk_counts.tolist()):
        if num_chunks <= 0:
            continue
        out[cursor : cursor + num_chunks, 0].fill_(seq_idx)
        out[cursor : cursor + num_chunks, 1] = torch.arange(
            num_chunks,
            dtype=out.dtype,
        )
        cursor += num_chunks
    return cursor


def _fill_chunk_offsets_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    out[0] = 0
    if chunk_counts.numel() > 0:
        torch.cumsum(chunk_counts, dim=0, out=out[1 : chunk_counts.numel() + 1])
    return chunk_counts.numel() + 1


def _fill_update_chunk_offsets_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    out[0] = 0
    if chunk_counts.numel() > 0:
        torch.cumsum(
            chunk_counts + 1,
            dim=0,
            out=out[1 : chunk_counts.numel() + 1],
        )
    return chunk_counts.numel() + 1


def _fill_final_chunk_indices_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int:
    if chunk_counts.numel() > 0:
        torch.cumsum(chunk_counts + 1, dim=0, out=out[: chunk_counts.numel()])
        out[: chunk_counts.numel()].sub_(1)
    return chunk_counts.numel()


def _get_gdn_num_heads(builder) -> int:
    hf_text_config = getattr(builder.vllm_config.model_config, "hf_text_config", None)
    if hf_text_config is not None and hasattr(hf_text_config, "linear_num_value_heads"):
        return hf_text_config.linear_num_value_heads // builder.vllm_config.parallel_config.tensor_parallel_size
    return builder.vllm_config.model_config.get_num_attention_heads(builder.vllm_config.parallel_config)


def _allocate_chunked_prefill_slot(builder, device: torch.device):
    max_num_batched_tokens = builder.vllm_config.scheduler_config.max_num_batched_tokens
    max_num_seqs = builder.vllm_config.scheduler_config.max_num_seqs
    return _GDNChunkedPrefillBufferSlot(
        chunk_indices_chunk64=CpuGpuBuffer(
            max_num_batched_tokens,
            2,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        chunk_offsets_chunk64=CpuGpuBuffer(
            max_num_seqs + 1,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        update_chunk_offsets_chunk64=CpuGpuBuffer(
            max_num_seqs + 1,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        final_chunk_indices_chunk64=CpuGpuBuffer(
            max_num_seqs,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        chunk_indices_large_block=CpuGpuBuffer(
            max_num_batched_tokens,
            2,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
        block_indices_cumsum=CpuGpuBuffer(
            max_num_batched_tokens,
            2,
            dtype=torch.int32,
            device=device,
            pin_memory=True,
            with_numpy=False,
        ),
    )


def _ensure_chunk_meta_state(builder, device: torch.device) -> None:
    if getattr(builder, "_ascend_gdn_chunk_meta_initialized", False):
        return
    builder._ascend_gdn_chunk_meta_initialized = True
    builder._ascend_gdn_chunk_meta_device = device
    builder._ascend_gdn_chunk_size = _GDN_CHUNK_SIZE
    builder._ascend_gdn_large_block_size = _GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE
    gdn_num_heads = _get_gdn_num_heads(builder)
    cumsum_chunks = max(1, _GDN_CUMSUM_WORKING_SET // (gdn_num_heads * builder._ascend_gdn_chunk_size))
    builder._ascend_gdn_cumsum_block_size = _next_power_of_2(cumsum_chunks)
    builder._ascend_gdn_chunked_prefill_pool_idx = -1
    builder._ascend_gdn_chunked_prefill_pool = []
    if device.type != "cpu":
        builder._ascend_gdn_chunked_prefill_pool = [
            _allocate_chunked_prefill_slot(builder, device),
            _allocate_chunked_prefill_slot(builder, device),
        ]


def _build_non_spec_query_start_loc_cpu(
    builder,
    attn_metadata,
    common_attn_metadata,
    num_decode_draft_tokens_cpu: torch.Tensor | None,
) -> torch.Tensor | None:
    if attn_metadata.num_prefills <= 0:
        return None

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    if (
        not getattr(builder, "use_spec_decode", False)
        or num_decode_draft_tokens_cpu is None
        or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0].sum().item() == 0
    ):
        return query_start_loc_cpu

    spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
    if spec_sequence_masks_cpu.sum().item() == 0:
        return query_start_loc_cpu

    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    non_spec_query_lens_cpu = query_lens_cpu[~spec_sequence_masks_cpu]
    non_spec_query_start_loc_cpu = torch.zeros(
        non_spec_query_lens_cpu.numel() + 1,
        dtype=query_start_loc_cpu.dtype,
    )
    torch.cumsum(
        non_spec_query_lens_cpu,
        dim=0,
        out=non_spec_query_start_loc_cpu[1:],
    )
    return non_spec_query_start_loc_cpu


def _build_non_spec_chunked_prefill_meta_cpu(builder, cu_seqlens_cpu: torch.Tensor) -> GDNChunkedPrefillMetadata:
    chunk_counts_chunk64 = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_chunk_size)
    chunk_counts_large = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_large_block_size)
    chunk_counts_cumsum = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_cumsum_block_size)
    num_seqs = chunk_counts_chunk64.numel()
    chunk_indices_chunk64 = torch.empty((int(chunk_counts_chunk64.sum().item()), 2), dtype=torch.int32)
    chunk_offsets_chunk64 = torch.empty((num_seqs + 1,), dtype=torch.int32)
    update_chunk_offsets_chunk64 = torch.empty((num_seqs + 1,), dtype=torch.int32)
    final_chunk_indices_chunk64 = torch.empty((num_seqs,), dtype=torch.int32)
    chunk_indices_large_block = torch.empty((int(chunk_counts_large.sum().item()), 2), dtype=torch.int32)
    block_indices_cumsum = torch.empty((int(chunk_counts_cumsum.sum().item()), 2), dtype=torch.int32)

    _fill_chunk_indices_cpu(chunk_indices_chunk64, chunk_counts_chunk64)
    _fill_chunk_offsets_cpu(chunk_offsets_chunk64, chunk_counts_chunk64)
    _fill_update_chunk_offsets_cpu(update_chunk_offsets_chunk64, chunk_counts_chunk64)
    _fill_final_chunk_indices_cpu(final_chunk_indices_chunk64, chunk_counts_chunk64)
    _fill_chunk_indices_cpu(chunk_indices_large_block, chunk_counts_large)
    _fill_chunk_indices_cpu(block_indices_cumsum, chunk_counts_cumsum)

    return GDNChunkedPrefillMetadata(
        chunk_indices_chunk64=chunk_indices_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        chunk_offsets_chunk64=chunk_offsets_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        update_chunk_offsets_chunk64=update_chunk_offsets_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        final_chunk_indices_chunk64=final_chunk_indices_chunk64.to(builder._ascend_gdn_chunk_meta_device),
        chunk_indices_large_block=chunk_indices_large_block.to(builder._ascend_gdn_chunk_meta_device),
        block_indices_cumsum=block_indices_cumsum.to(builder._ascend_gdn_chunk_meta_device),
    )


def _build_non_spec_chunked_prefill_meta(builder, cu_seqlens_cpu: torch.Tensor) -> GDNChunkedPrefillMetadata:
    device = builder._ascend_gdn_chunk_meta_device
    if device.type == "cpu":
        return _build_non_spec_chunked_prefill_meta_cpu(builder, cu_seqlens_cpu)

    builder._ascend_gdn_chunked_prefill_pool_idx = (builder._ascend_gdn_chunked_prefill_pool_idx + 1) % len(
        builder._ascend_gdn_chunked_prefill_pool
    )
    slot = builder._ascend_gdn_chunked_prefill_pool[builder._ascend_gdn_chunked_prefill_pool_idx]
    chunk_counts_chunk64 = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_chunk_size)
    chunk_counts_large = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_large_block_size)
    chunk_counts_cumsum = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_cumsum_block_size)
    num_chunk_indices_chunk64 = _fill_chunk_indices_cpu(slot.chunk_indices_chunk64.cpu, chunk_counts_chunk64)
    num_chunk_offsets_chunk64 = _fill_chunk_offsets_cpu(slot.chunk_offsets_chunk64.cpu, chunk_counts_chunk64)
    num_update_chunk_offsets_chunk64 = _fill_update_chunk_offsets_cpu(
        slot.update_chunk_offsets_chunk64.cpu, chunk_counts_chunk64
    )
    num_final_chunk_indices_chunk64 = _fill_final_chunk_indices_cpu(
        slot.final_chunk_indices_chunk64.cpu, chunk_counts_chunk64
    )
    num_chunk_indices_large = _fill_chunk_indices_cpu(slot.chunk_indices_large_block.cpu, chunk_counts_large)
    num_block_indices_cumsum = _fill_chunk_indices_cpu(slot.block_indices_cumsum.cpu, chunk_counts_cumsum)

    chunk_indices_chunk64 = slot.chunk_indices_chunk64.copy_to_gpu(num_chunk_indices_chunk64)
    chunk_offsets_chunk64 = slot.chunk_offsets_chunk64.copy_to_gpu(num_chunk_offsets_chunk64)
    update_chunk_offsets_chunk64 = slot.update_chunk_offsets_chunk64.copy_to_gpu(num_update_chunk_offsets_chunk64)
    final_chunk_indices_chunk64 = slot.final_chunk_indices_chunk64.copy_to_gpu(num_final_chunk_indices_chunk64)
    chunk_indices_large_block = slot.chunk_indices_large_block.copy_to_gpu(num_chunk_indices_large)
    block_indices_cumsum = slot.block_indices_cumsum.copy_to_gpu(num_block_indices_cumsum)
    return GDNChunkedPrefillMetadata(
        chunk_indices_chunk64=chunk_indices_chunk64,
        chunk_offsets_chunk64=chunk_offsets_chunk64,
        update_chunk_offsets_chunk64=update_chunk_offsets_chunk64,
        final_chunk_indices_chunk64=final_chunk_indices_chunk64,
        chunk_indices_large_block=chunk_indices_large_block,
        block_indices_cumsum=block_indices_cumsum,
        _buffer_slot=slot,
    )


def _patched_build(
    self,
    common_prefix_len: int,
    common_attn_metadata,
    num_accepted_tokens: torch.Tensor | None = None,
    num_decode_draft_tokens_cpu: torch.Tensor | None = None,
    fast_build: bool = False,
):
    attn_metadata = _ORIGINAL_BUILD(
        self,
        common_prefix_len,
        common_attn_metadata,
        num_accepted_tokens=num_accepted_tokens,
        num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
        fast_build=fast_build,
    )
    attn_metadata.non_spec_chunked_prefill_meta = None
    if attn_metadata.num_prefills <= 0:
        return attn_metadata

    _ensure_chunk_meta_state(self, common_attn_metadata.query_start_loc.device)
    non_spec_query_start_loc_cpu = _build_non_spec_query_start_loc_cpu(
        self,
        attn_metadata,
        common_attn_metadata,
        num_decode_draft_tokens_cpu,
    )
    assert non_spec_query_start_loc_cpu is not None
    attn_metadata.non_spec_chunked_prefill_meta = _build_non_spec_chunked_prefill_meta(
        self, non_spec_query_start_loc_cpu
    )
    return attn_metadata


if not _IS_PATCHED:
    gdn_attn.GDNChunkedPrefillMetadata = GDNChunkedPrefillMetadata
    gdn_attn.GDNAttentionMetadataBuilder.build = _patched_build
    _IS_PATCHED = True
@@ -178,6 +178,11 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
         if attn_metadata.num_prefills > 0:
             initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
             initial_state[~has_initial_state, ...] = 0
+            non_spec_chunked_prefill_meta = getattr(
+                attn_metadata,
+                "non_spec_chunked_prefill_meta",
+                None,
+            )
             (
                 core_attn_out_non_spec,
                 last_recurrent_state,
@@ -190,6 +195,7 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
                 initial_state=initial_state,
                 output_final_state=True,
                 cu_seqlens=non_spec_query_start_loc,
+                prebuilt_meta=non_spec_chunked_prefill_meta,
                 head_first=False,
                 use_qk_l2norm_in_kernel=True,
             )
@@ -237,6 +237,11 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
             initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous()

             initial_state[~has_initial_state, ...] = 0
+            non_spec_chunked_prefill_meta = getattr(
+                attn_metadata,
+                "non_spec_chunked_prefill_meta",
+                None,
+            )
             (
                 core_attn_out_non_spec,
                 last_recurrent_state,
@@ -249,6 +254,7 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
                 initial_state=initial_state,
                 output_final_state=True,
                 cu_seqlens=non_spec_query_start_loc,
+                prebuilt_meta=non_spec_chunked_prefill_meta,
                 head_first=False,
                 use_qk_l2norm_in_kernel=True,
             )