diff --git a/tests/ut/patch/worker/patch_common/test_patch_gdn_attn.py b/tests/ut/patch/worker/patch_common/test_patch_gdn_attn.py new file mode 100644 index 00000000..37a2a8c7 --- /dev/null +++ b/tests/ut/patch/worker/patch_common/test_patch_gdn_attn.py @@ -0,0 +1,389 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest +import torch + +import vllm_ascend.patch.worker.patch_gdn_attn as patch_gdn_attn +from vllm.config.compilation import CUDAGraphMode +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.kv_cache_interface import MambaSpec + + +@dataclass +class BatchSpec: + seq_lens: list[int] + query_lens: list[int] + name: str = "unnamed" + + @property + def batch_size(self): + return len(self.seq_lens) + + +def create_common_attn_metadata( + batch_spec: BatchSpec, + block_size: int, + device: torch.device, +) -> CommonAttentionMetadata: + query_start_loc = torch.zeros( + batch_spec.batch_size + 1, + dtype=torch.int32, + device=device, + ) + query_start_loc[1:] = torch.tensor( + batch_spec.query_lens, + dtype=torch.int32, + device=device, + ).cumsum(0) + query_start_loc_cpu = query_start_loc.cpu() + num_tokens = sum(batch_spec.query_lens) + + seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device) + seq_lens_cpu = seq_lens.cpu() + max_seq_len = int(seq_lens_cpu.max()) + context_lens = [ + batch_spec.seq_lens[i] - batch_spec.query_lens[i] + for i in range(batch_spec.batch_size) + ] + num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) + max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size + block_table_tensor = torch.arange( + batch_spec.batch_size * max_blocks, + dtype=torch.int32, + device=device, + ).view(batch_spec.batch_size, max_blocks) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=batch_spec.batch_size, + num_actual_tokens=num_tokens, + max_query_len=max(batch_spec.query_lens), + max_seq_len=max_seq_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + causal=True, + ) + + +def _make_vllm_config( + *, + max_model_len: int = 8192, + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + num_heads: int = 32, + num_speculative_tokens: int = 0, +): + speculative_config = None + if num_speculative_tokens > 0: + speculative_config = SimpleNamespace( + num_speculative_tokens=num_speculative_tokens, + parallel_drafting=False, + ) + + model_config = SimpleNamespace(max_model_len=max_model_len) + model_config.get_num_attention_heads = lambda parallel_config: num_heads + + return SimpleNamespace( + cache_config=SimpleNamespace(mamba_cache_mode="none"), + compilation_config=SimpleNamespace( + cudagraph_mode=CUDAGraphMode.NONE, + max_cudagraph_capture_size=None, + ), + speculative_config=speculative_config, + scheduler_config=SimpleNamespace( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + ), + parallel_config=SimpleNamespace( + decode_context_parallel_size=1, + tensor_parallel_size=1, + ), + model_config=model_config, + ) + + +def _make_builder(*, device: torch.device, num_heads: int, num_speculative_tokens: int): + vllm_config = 
_make_vllm_config( + num_heads=num_heads, + num_speculative_tokens=num_speculative_tokens, + ) + spec = MambaSpec( + block_size=16, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="none", + ) + return GDNAttentionMetadataBuilder(spec, ["layer0"], vllm_config, device) + + +def _next_power_of_2(value: int) -> int: + if value <= 1: + return 1 + return 1 << (value - 1).bit_length() + + +def _prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor: + lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + pairs: list[list[int]] = [] + for seq_idx, seq_len in enumerate(lens): + num_chunks = (seq_len + chunk_size - 1) // chunk_size + for chunk_idx in range(num_chunks): + pairs.append([seq_idx, chunk_idx]) + if not pairs: + return torch.empty((0, 2), dtype=cu_seqlens.dtype, device=cu_seqlens.device) + return torch.tensor(pairs, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + + +def _prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor: + lens = cu_seqlens[1:] - cu_seqlens[:-1] + num_chunks = torch.div( + lens + chunk_size - 1, + chunk_size, + rounding_mode="floor", + ) + offsets = torch.zeros(len(num_chunks) + 1, dtype=cu_seqlens.dtype) + torch.cumsum(num_chunks, dim=0, out=offsets[1:]) + return offsets.to(cu_seqlens.device) + + +def _prepare_update_chunk_offsets( + cu_seqlens: torch.Tensor, chunk_size: int +) -> torch.Tensor: + lens = cu_seqlens[1:] - cu_seqlens[:-1] + num_chunks = torch.div( + lens + chunk_size - 1, + chunk_size, + rounding_mode="floor", + ) + 1 + offsets = torch.zeros(len(num_chunks) + 1, dtype=cu_seqlens.dtype) + torch.cumsum(num_chunks, dim=0, out=offsets[1:]) + return offsets.to(cu_seqlens.device) + + +def _prepare_final_chunk_indices( + cu_seqlens: torch.Tensor, chunk_size: int +) -> torch.Tensor: + lens = cu_seqlens[1:] - cu_seqlens[:-1] + num_chunks = torch.div( + lens + chunk_size - 1, + chunk_size, + rounding_mode="floor", + ) + 1 + return (torch.cumsum(num_chunks, dim=0) - 1).to(cu_seqlens.device) + + +def _build_non_spec_query_start_loc_cpu( + batch_spec: BatchSpec, spec_mask_cpu: torch.Tensor | None +) -> torch.Tensor: + query_lens = torch.tensor(batch_spec.query_lens, dtype=torch.int32) + if spec_mask_cpu is not None: + query_lens = query_lens[~spec_mask_cpu] + query_start_loc = torch.zeros(query_lens.numel() + 1, dtype=torch.int32) + torch.cumsum(query_lens, dim=0, out=query_start_loc[1:]) + return query_start_loc + + +@pytest.mark.parametrize( + ("batch_spec", "num_speculative_tokens", "num_decode_draft_tokens_cpu"), + [ + ( + BatchSpec( + seq_lens=[8, 12], + query_lens=[4, 8], + name="pure_non_spec_prefill", + ), + 0, + None, + ), + ( + BatchSpec( + seq_lens=[8, 4, 0, 12], + query_lens=[4, 4, 0, 8], + name="mixed_spec_non_spec_with_padding", + ), + 3, + torch.tensor([-1, 3, -1, -1], dtype=torch.int32), + ), + ( + BatchSpec( + seq_lens=[5, 12, 0, 9], + query_lens=[1, 8, 0, 1], + name="mixed_prefill_decode_without_spec", + ), + 0, + None, + ), + ], + ids=lambda case: case.name if isinstance(case, BatchSpec) else None, +) +def test_builder_prebuilds_non_spec_chunk_metadata_exactly( + batch_spec: BatchSpec, + num_speculative_tokens: int, + num_decode_draft_tokens_cpu: torch.Tensor | None, +): + device = torch.device("cpu") + num_heads = 32 + common_attn_metadata = create_common_attn_metadata( + batch_spec=batch_spec, + block_size=16, + device=device, + ) + builder = _make_builder( + device=device, + num_heads=num_heads, + num_speculative_tokens=num_speculative_tokens, + ) + + 
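+    # A request is treated as speculative when its num_decode_draft_tokens_cpu
+    # entry is >= 0; those requests are dropped from the non-spec cu_seqlens
+    # before the prebuilt metadata is compared against the legacy prepare_* helpers.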
num_accepted_tokens = None + spec_mask_cpu = None + if num_decode_draft_tokens_cpu is not None: + num_accepted_tokens = torch.ones( + batch_spec.batch_size, dtype=torch.int32, device=device + ) + spec_mask_cpu = num_decode_draft_tokens_cpu >= 0 + + attn_metadata = builder.build( + 0, + common_attn_metadata, + num_accepted_tokens=num_accepted_tokens, + num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu, + ) + + non_spec_query_start_loc_cpu = _build_non_spec_query_start_loc_cpu( + batch_spec, + spec_mask_cpu, + ) + legacy_chunk_indices_64 = _prepare_chunk_indices(non_spec_query_start_loc_cpu, 64) + legacy_chunk_offsets_64 = _prepare_chunk_offsets(non_spec_query_start_loc_cpu, 64) + legacy_update_chunk_offsets_64 = _prepare_update_chunk_offsets( + non_spec_query_start_loc_cpu, + 64, + ) + legacy_final_chunk_indices_64 = _prepare_final_chunk_indices( + non_spec_query_start_loc_cpu, + 64, + ) + legacy_chunk_indices_large_block = _prepare_chunk_indices( + non_spec_query_start_loc_cpu, + patch_gdn_attn._GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE, + ) + optim_block_size = _next_power_of_2( + patch_gdn_attn._GDN_CUMSUM_WORKING_SET + // (num_heads * patch_gdn_attn._GDN_CHUNK_SIZE) + ) + legacy_block_indices_cumsum = _prepare_chunk_indices( + non_spec_query_start_loc_cpu, + optim_block_size, + ) + + prebuilt_meta = getattr(attn_metadata, "non_spec_chunked_prefill_meta", None) + assert prebuilt_meta is not None + assert torch.equal(prebuilt_meta.chunk_indices_chunk64, legacy_chunk_indices_64) + assert torch.equal(prebuilt_meta.chunk_offsets_chunk64, legacy_chunk_offsets_64) + assert torch.equal( + prebuilt_meta.update_chunk_offsets_chunk64, legacy_update_chunk_offsets_64 + ) + assert torch.equal( + prebuilt_meta.final_chunk_indices_chunk64, legacy_final_chunk_indices_64 + ) + assert torch.equal( + prebuilt_meta.chunk_indices_large_block, + legacy_chunk_indices_large_block, + ) + assert torch.equal( + prebuilt_meta.block_indices_cumsum, + legacy_block_indices_cumsum, + ) + + +def test_allocate_chunked_prefill_slot_uses_cpugpubuffer(monkeypatch): + class DummyCpuGpuBuffer: + def __init__( + self, + *size, + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + with_numpy: bool = True, + ) -> None: + self.cpu = torch.zeros(*size, dtype=dtype, device="cpu") + self.gpu = torch.zeros_like(self.cpu, device=device) + self.dtype = dtype + self.device = device + self.pin_memory = pin_memory + self.with_numpy = with_numpy + + device = torch.device("cpu") + builder = _make_builder( + device=device, + num_heads=32, + num_speculative_tokens=0, + ) + monkeypatch.setattr(patch_gdn_attn, "CpuGpuBuffer", DummyCpuGpuBuffer) + + slot = patch_gdn_attn._allocate_chunked_prefill_slot(builder, device) + + assert isinstance(slot.chunk_indices_chunk64, DummyCpuGpuBuffer) + assert isinstance(slot.chunk_offsets_chunk64, DummyCpuGpuBuffer) + assert isinstance(slot.update_chunk_offsets_chunk64, DummyCpuGpuBuffer) + assert isinstance(slot.final_chunk_indices_chunk64, DummyCpuGpuBuffer) + assert slot.chunk_indices_chunk64.pin_memory is True + assert slot.chunk_indices_chunk64.with_numpy is False + assert slot.chunk_indices_chunk64.device == device + assert slot.chunk_indices_chunk64.cpu.shape == ( + builder.vllm_config.scheduler_config.max_num_batched_tokens, + 2, + ) + assert slot.chunk_indices_chunk64.gpu.shape == ( + builder.vllm_config.scheduler_config.max_num_batched_tokens, + 2, + ) + + +@pytest.mark.parametrize( + "batch_spec", + [ + BatchSpec(seq_lens=[1, 1, 1], query_lens=[1, 1, 1], name="decode_only"), + 
BatchSpec(seq_lens=[4, 4], query_lens=[4, 4], name="spec_only"), + ], +) +def test_builder_skips_prebuilt_meta_without_non_spec_prefill(batch_spec: BatchSpec): + device = torch.device("cpu") + builder = _make_builder( + device=device, + num_heads=32, + num_speculative_tokens=3 if batch_spec.name == "spec_only" else 0, + ) + common_attn_metadata = create_common_attn_metadata( + batch_spec=batch_spec, + block_size=16, + device=device, + ) + + num_accepted_tokens = None + num_decode_draft_tokens_cpu = None + if batch_spec.name == "spec_only": + num_accepted_tokens = torch.ones( + batch_spec.batch_size, dtype=torch.int32, device=device + ) + num_decode_draft_tokens_cpu = torch.full( + (batch_spec.batch_size,), 3, dtype=torch.int32 + ) + + attn_metadata = builder.build( + 0, + common_attn_metadata, + num_accepted_tokens=num_accepted_tokens, + num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu, + ) + + assert getattr(attn_metadata, "non_spec_chunked_prefill_meta", None) is None diff --git a/vllm_ascend/ops/triton/fla/chunk.py b/vllm_ascend/ops/triton/fla/chunk.py index 125bc13e..8935ea4c 100644 --- a/vllm_ascend/ops/triton/fla/chunk.py +++ b/vllm_ascend/ops/triton/fla/chunk.py @@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd( initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, + prebuilt_meta=None, ): forward_context = get_forward_context() num_decodes = 0 @@ -47,10 +48,34 @@ def chunk_gated_delta_rule_fwd( if attn_metadata is not None: num_decodes = attn_metadata.num_decodes chunk_size = 64 - g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) + block_indices_cumsum = None if prebuilt_meta is None else prebuilt_meta.block_indices_cumsum + chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_chunk64 + chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.chunk_offsets_chunk64 + update_chunk_offsets_chunk64 = None if prebuilt_meta is None else prebuilt_meta.update_chunk_offsets_chunk64 + final_chunk_indices_chunk64 = None if prebuilt_meta is None else prebuilt_meta.final_chunk_indices_chunk64 + chunk_indices_large_block = None if prebuilt_meta is None else prebuilt_meta.chunk_indices_large_block + g = chunk_local_cumsum( + g, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + block_indices=block_indices_cumsum, + ) # obtain WY representation. u is actually the new v. 
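+    # When the builder supplies prebuilt_meta, the chunk index/offset tensors
+    # unpacked above are reused by the kernel wrappers below instead of being
+    # rebuilt on the host via prepare_chunk_indices/prepare_chunk_offsets.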
- A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) - A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + A = chunk_scaled_dot_kkt_fwd( + k=k, + beta=beta, + g_cumsum=g, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices_chunk64, + output_dtype=torch.float32, + ) + A = solve_tril( + A=A, + cu_seqlens=cu_seqlens, + chunk_indices_large_block=chunk_indices_large_block, + chunk_indices_bt=chunk_indices_chunk64, + output_dtype=k.dtype, + ) w, u = recompute_w_u_fwd( k=k, v=v, @@ -58,6 +83,7 @@ def chunk_gated_delta_rule_fwd( A=A, g_cumsum=g, cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices_chunk64, ) h, v_new, final_state = chunk_gated_delta_rule_fwd_h( k=k, @@ -67,6 +93,8 @@ def chunk_gated_delta_rule_fwd( initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices_chunk64, + chunk_offsets=chunk_offsets_chunk64, ) if get_pcp_group().world_size > 1: @@ -76,10 +104,15 @@ def chunk_gated_delta_rule_fwd( u=u, g=g, cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices_chunk64, + chunk_offsets=chunk_offsets_chunk64, + update_chunk_offsets=update_chunk_offsets_chunk64, num_decodes=num_decodes, ) all_final_state = get_pcp_group().all_gather(final_state.unsqueeze(0), 0) - final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size) + final_chunk_indices = final_chunk_indices_chunk64 + if final_chunk_indices is None: + final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size) final_h_update = h_update[:, final_chunk_indices, :, :, :] all_final_h_update = get_pcp_group().all_gather(final_h_update, 0) @@ -137,6 +170,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function): initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, + prebuilt_meta=None, use_qk_l2norm_in_kernel: bool = False, ): if use_qk_l2norm_in_kernel: @@ -152,6 +186,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function): initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, + prebuilt_meta=prebuilt_meta, ) ctx.scale = scale ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel @@ -169,6 +204,7 @@ def chunk_gated_delta_rule( initial_state: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: torch.LongTensor | None = None, + prebuilt_meta=None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False, ): @@ -268,7 +304,17 @@ def chunk_gated_delta_rule( if scale is None: scale = k.shape[-1] ** -0.5 o, final_state = ChunkGatedDeltaRuleFunction.apply( - q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + prebuilt_meta, + use_qk_l2norm_in_kernel, ) if head_first: o = rearrange(o, "b t h ... -> b h t ...") diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_h.py b/vllm_ascend/ops/triton/fla/chunk_delta_h.py index 85eab41c..e305878e 100644 --- a/vllm_ascend/ops/triton/fla/chunk_delta_h.py +++ b/vllm_ascend/ops/triton/fla/chunk_delta_h.py @@ -186,6 +186,8 @@ def chunk_gated_delta_rule_fwd_h( chunk_size: int = 64, # SY: remove this argument and force chunk size 64? 
save_new_value: bool = True, cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.Tensor | None = None, + chunk_offsets: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # This kernel is slightly different from fla to support Q/K with different head numbers. # In fla, Q/K always have the same head number, so Hg is always equal to H. @@ -193,15 +195,18 @@ def chunk_gated_delta_rule_fwd_h( H = u.shape[-2] BT = chunk_size - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + if cu_seqlens is not None and chunk_indices is None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) # N: the actual number of sequences in the batch with either equal or variable lengths if cu_seqlens is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None else: + if chunk_offsets is None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) N, NT, chunk_offsets = ( len(cu_seqlens) - 1, len(chunk_indices), - prepare_chunk_offsets(cu_seqlens, BT), + chunk_offsets, ) assert K <= 256, "current kernel does not support head dimension larger than 256." diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py b/vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py index 7ab1ef78..5b91fd9b 100644 --- a/vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py +++ b/vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py @@ -163,6 +163,9 @@ def chunk_gated_delta_rule_fwd_hupdate( g: torch.Tensor | None = None, chunk_size: int = 64, # SY: remove this argument and force chunk size 64? cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.Tensor | None = None, + chunk_offsets: torch.Tensor | None = None, + update_chunk_offsets: torch.Tensor | None = None, num_decodes: int = 0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # This kernel is slightly different from fla to support Q/K with different head numbers. @@ -171,20 +174,25 @@ def chunk_gated_delta_rule_fwd_hupdate( H = u.shape[-2] BT = chunk_size - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + if cu_seqlens is not None and chunk_indices is None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) # N: the actual number of sequences in the batch with either equal or variable lengths if cu_seqlens is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None else: + if chunk_offsets is None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) N, NT, chunk_offsets = ( len(cu_seqlens) - 1, len(chunk_indices), - prepare_chunk_offsets(cu_seqlens, BT), + chunk_offsets, ) assert K <= 256, "current kernel does not support head dimension larger than 256." 
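+    # h_update packs NT chunk states plus one extra identity slot per sequence
+    # (hence NT + N); update_chunk_offsets[:-1] gives each sequence's start slot,
+    # where the identity block is written below.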
h_update = k.new_empty(B, NT + N, H, K, K, dtype=torch.float32) - update_indices = prepare_update_chunk_offsets(cu_seqlens, BT)[:-1] + if cu_seqlens is not None and update_chunk_offsets is None: + update_chunk_offsets = prepare_update_chunk_offsets(cu_seqlens, BT) + update_indices = update_chunk_offsets[:-1] h_update[:, update_indices, :, :, :] = torch.eye(K, dtype=h_update.dtype, device=h_update.device) g = g.transpose(1, 2).contiguous() diff --git a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py index 1ad1aead..7ceafe3c 100644 --- a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py +++ b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py @@ -85,6 +85,7 @@ def chunk_scaled_dot_kkt_fwd( beta: torch.Tensor, g_cumsum: torch.Tensor | None = None, cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.Tensor | None = None, chunk_size: int = 64, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: @@ -115,13 +116,8 @@ def chunk_scaled_dot_kkt_fwd( H = beta.shape[-1] BT = chunk_size - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.cpu() - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - chunk_indices = chunk_indices.npu() - cu_seqlens = cu_seqlens.npu() - else: - chunk_indices = None + if cu_seqlens is not None and chunk_indices is None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) diff --git a/vllm_ascend/ops/triton/fla/cumsum.py b/vllm_ascend/ops/triton/fla/cumsum.py index da7bf8c9..ea2fa143 100644 --- a/vllm_ascend/ops/triton/fla/cumsum.py +++ b/vllm_ascend/ops/triton/fla/cumsum.py @@ -81,6 +81,7 @@ def chunk_local_cumsum_scalar( reverse: bool = False, scale: float = None, cu_seqlens: torch.Tensor | None = None, + block_indices: torch.Tensor | None = None, head_first: bool = False, output_dtype: torch.Tensor | None = torch.float, ): @@ -90,7 +91,8 @@ def chunk_local_cumsum_scalar( B, T, H = g.shape assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size)) - block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None + if cu_seqlens is not None and block_indices is None: + block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) grid = (num_blocks, B) @@ -132,6 +134,7 @@ def chunk_local_cumsum( reverse=reverse, scale=scale, cu_seqlens=cu_seqlens, + block_indices=kwargs.get("block_indices"), head_first=head_first, output_dtype=output_dtype, ) diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index 493b182b..1783035d 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -330,6 +330,8 @@ def merge_16x16_to_64x64_inverse_kernel( def solve_tril( A: torch.Tensor, cu_seqlens: torch.Tensor | None = None, + chunk_indices_large_block: torch.Tensor | None = None, + chunk_indices_bt: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float, ) -> torch.Tensor: """ @@ -355,7 +357,9 @@ def solve_tril( LARGE_BLOCK_T = 608 * 2 - chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if 
cu_seqlens is not None else None + if cu_seqlens is not None and chunk_indices_large_block is None: + chunk_indices_large_block = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) + chunk_indices = chunk_indices_large_block NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T) solve_tril_16x16_kernel[NT, B * H]( @@ -376,7 +380,9 @@ def solve_tril( Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + if cu_seqlens is not None and chunk_indices_bt is None: + chunk_indices_bt = prepare_chunk_indices(cu_seqlens, BT) + chunk_indices = chunk_indices_bt NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) merge_fn[NT, B * H]( diff --git a/vllm_ascend/ops/triton/fla/wy_fast.py b/vllm_ascend/ops/triton/fla/wy_fast.py index d6e24075..12cdfcad 100644 --- a/vllm_ascend/ops/triton/fla/wy_fast.py +++ b/vllm_ascend/ops/triton/fla/wy_fast.py @@ -102,12 +102,14 @@ def recompute_w_u_fwd( g_cumsum: torch.Tensor, A: torch.Tensor, cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, v.shape[-1] H = v.shape[-2] BT = A.shape[-1] - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + if cu_seqlens is not None and chunk_indices is None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) BK = 64 diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index ad6d1d9b..4b167db0 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -259,6 +259,20 @@ # Remove this patch when bool is supported in 'torch.argsort' func of npu. # Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable. # +# ** 7a. File: worker/patch_gdn_attn.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.attention.backends.gdn_attn.GDNAttentionMetadataBuilder.build` +# Why: +# Qwen3.5/Qwen3Next GDN prefill on NPU needs prebuilt varlen chunk metadata +# to avoid forward-time host round-trips that break async scheduling. +# How: +# Monkey-patch the upstream builder in-place, keep upstream code untouched, +# and attach prebuilt device metadata bundle onto the returned attention +# metadata object for Ascend-specific consumers. +# Future Plan: +# Remove this patch when upstream exposes a backend hook for extending GDN +# metadata or when the optimization is accepted upstream directly. +# # ** 8. File: worker/patch_qwen3_next.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. 
`vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet.forward` diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 11982294..48c7b4e6 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -31,6 +31,7 @@ import vllm_ascend.patch.worker.patch_minimax_m2 # noqa import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn # noqa import vllm_ascend.patch.worker.patch_mamba_utils # noqa import vllm_ascend.patch.worker.patch_multimodal_merge # noqa +import vllm_ascend.patch.worker.patch_gdn_attn # noqa import vllm_ascend.patch.worker.patch_qwen3_next # noqa import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa import vllm_ascend.patch.worker.patch_qwen3_5 # noqa diff --git a/vllm_ascend/patch/worker/patch_gdn_attn.py b/vllm_ascend/patch/worker/patch_gdn_attn.py new file mode 100644 index 00000000..14e134c3 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_gdn_attn.py @@ -0,0 +1,321 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch +import vllm.v1.attention.backends.gdn_attn as gdn_attn +from vllm.v1.utils import CpuGpuBuffer + +_GDN_CHUNK_SIZE = 64 +# Keep this aligned with solve_tril.LARGE_BLOCK_T in ops/triton/fla/solve_tril.py. 
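+# Likewise, _GDN_CUMSUM_WORKING_SET mirrors the 2**18 working-set bound that
+# chunk_local_cumsum_scalar uses to size OPTIM_BLOCK_SIZE in ops/triton/fla/cumsum.py.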
+_GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE = 608 * 2 +_GDN_CUMSUM_WORKING_SET = 2**18 + +_IS_PATCHED = False +_ORIGINAL_BUILD = gdn_attn.GDNAttentionMetadataBuilder.build + + +@dataclass +class GDNChunkedPrefillMetadata: + chunk_indices_chunk64: torch.Tensor + chunk_offsets_chunk64: torch.Tensor + update_chunk_offsets_chunk64: torch.Tensor + final_chunk_indices_chunk64: torch.Tensor + chunk_indices_large_block: torch.Tensor + block_indices_cumsum: torch.Tensor + _buffer_slot: object | None = None + + +@dataclass +class _GDNChunkedPrefillBufferSlot: + chunk_indices_chunk64: CpuGpuBuffer + chunk_offsets_chunk64: CpuGpuBuffer + update_chunk_offsets_chunk64: CpuGpuBuffer + final_chunk_indices_chunk64: CpuGpuBuffer + chunk_indices_large_block: CpuGpuBuffer + block_indices_cumsum: CpuGpuBuffer + + +def _next_power_of_2(value: int) -> int: + if value <= 1: + return 1 + return 1 << (value - 1).bit_length() + + +def _prepare_chunk_counts_cpu(cu_seqlens_cpu: torch.Tensor, chunk_size: int) -> torch.Tensor: + lens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1] + return torch.div(lens + chunk_size - 1, chunk_size, rounding_mode="floor") + + +def _fill_chunk_indices_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int: + cursor = 0 + for seq_idx, num_chunks in enumerate(chunk_counts.tolist()): + if num_chunks <= 0: + continue + out[cursor : cursor + num_chunks, 0].fill_(seq_idx) + out[cursor : cursor + num_chunks, 1] = torch.arange( + num_chunks, + dtype=out.dtype, + ) + cursor += num_chunks + return cursor + + +def _fill_chunk_offsets_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int: + out[0] = 0 + if chunk_counts.numel() > 0: + torch.cumsum(chunk_counts, dim=0, out=out[1 : chunk_counts.numel() + 1]) + return chunk_counts.numel() + 1 + + +def _fill_update_chunk_offsets_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int: + out[0] = 0 + if chunk_counts.numel() > 0: + torch.cumsum( + chunk_counts + 1, + dim=0, + out=out[1 : chunk_counts.numel() + 1], + ) + return chunk_counts.numel() + 1 + + +def _fill_final_chunk_indices_cpu(out: torch.Tensor, chunk_counts: torch.Tensor) -> int: + if chunk_counts.numel() > 0: + torch.cumsum(chunk_counts + 1, dim=0, out=out[: chunk_counts.numel()]) + out[: chunk_counts.numel()].sub_(1) + return chunk_counts.numel() + + +def _get_gdn_num_heads(builder) -> int: + hf_text_config = getattr(builder.vllm_config.model_config, "hf_text_config", None) + if hf_text_config is not None and hasattr(hf_text_config, "linear_num_value_heads"): + return hf_text_config.linear_num_value_heads // builder.vllm_config.parallel_config.tensor_parallel_size + return builder.vllm_config.model_config.get_num_attention_heads(builder.vllm_config.parallel_config) + + +def _allocate_chunked_prefill_slot(builder, device: torch.device): + max_num_batched_tokens = builder.vllm_config.scheduler_config.max_num_batched_tokens + max_num_seqs = builder.vllm_config.scheduler_config.max_num_seqs + return _GDNChunkedPrefillBufferSlot( + chunk_indices_chunk64=CpuGpuBuffer( + max_num_batched_tokens, + 2, + dtype=torch.int32, + device=device, + pin_memory=True, + with_numpy=False, + ), + chunk_offsets_chunk64=CpuGpuBuffer( + max_num_seqs + 1, + dtype=torch.int32, + device=device, + pin_memory=True, + with_numpy=False, + ), + update_chunk_offsets_chunk64=CpuGpuBuffer( + max_num_seqs + 1, + dtype=torch.int32, + device=device, + pin_memory=True, + with_numpy=False, + ), + final_chunk_indices_chunk64=CpuGpuBuffer( + max_num_seqs, + dtype=torch.int32, + device=device, + pin_memory=True, + with_numpy=False, + ), + 
chunk_indices_large_block=CpuGpuBuffer( + max_num_batched_tokens, + 2, + dtype=torch.int32, + device=device, + pin_memory=True, + with_numpy=False, + ), + block_indices_cumsum=CpuGpuBuffer( + max_num_batched_tokens, + 2, + dtype=torch.int32, + device=device, + pin_memory=True, + with_numpy=False, + ), + ) + + +def _ensure_chunk_meta_state(builder, device: torch.device) -> None: + if getattr(builder, "_ascend_gdn_chunk_meta_initialized", False): + return + builder._ascend_gdn_chunk_meta_initialized = True + builder._ascend_gdn_chunk_meta_device = device + builder._ascend_gdn_chunk_size = _GDN_CHUNK_SIZE + builder._ascend_gdn_large_block_size = _GDN_SOLVE_TRIL_LARGE_BLOCK_SIZE + gdn_num_heads = _get_gdn_num_heads(builder) + cumsum_chunks = max(1, _GDN_CUMSUM_WORKING_SET // (gdn_num_heads * builder._ascend_gdn_chunk_size)) + builder._ascend_gdn_cumsum_block_size = _next_power_of_2(cumsum_chunks) + builder._ascend_gdn_chunked_prefill_pool_idx = -1 + builder._ascend_gdn_chunked_prefill_pool = [] + if device.type != "cpu": + builder._ascend_gdn_chunked_prefill_pool = [ + _allocate_chunked_prefill_slot(builder, device), + _allocate_chunked_prefill_slot(builder, device), + ] + + +def _build_non_spec_query_start_loc_cpu( + builder, + attn_metadata, + common_attn_metadata, + num_decode_draft_tokens_cpu: torch.Tensor | None, +) -> torch.Tensor | None: + if attn_metadata.num_prefills <= 0: + return None + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if ( + not getattr(builder, "use_spec_decode", False) + or num_decode_draft_tokens_cpu is None + or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0].sum().item() == 0 + ): + return query_start_loc_cpu + + spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0 + if spec_sequence_masks_cpu.sum().item() == 0: + return query_start_loc_cpu + + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + non_spec_query_lens_cpu = query_lens_cpu[~spec_sequence_masks_cpu] + non_spec_query_start_loc_cpu = torch.zeros( + non_spec_query_lens_cpu.numel() + 1, + dtype=query_start_loc_cpu.dtype, + ) + torch.cumsum( + non_spec_query_lens_cpu, + dim=0, + out=non_spec_query_start_loc_cpu[1:], + ) + return non_spec_query_start_loc_cpu + + +def _build_non_spec_chunked_prefill_meta_cpu(builder, cu_seqlens_cpu: torch.Tensor) -> GDNChunkedPrefillMetadata: + chunk_counts_chunk64 = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_chunk_size) + chunk_counts_large = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_large_block_size) + chunk_counts_cumsum = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_cumsum_block_size) + num_seqs = chunk_counts_chunk64.numel() + chunk_indices_chunk64 = torch.empty((int(chunk_counts_chunk64.sum().item()), 2), dtype=torch.int32) + chunk_offsets_chunk64 = torch.empty((num_seqs + 1,), dtype=torch.int32) + update_chunk_offsets_chunk64 = torch.empty((num_seqs + 1,), dtype=torch.int32) + final_chunk_indices_chunk64 = torch.empty((num_seqs,), dtype=torch.int32) + chunk_indices_large_block = torch.empty((int(chunk_counts_large.sum().item()), 2), dtype=torch.int32) + block_indices_cumsum = torch.empty((int(chunk_counts_cumsum.sum().item()), 2), dtype=torch.int32) + + _fill_chunk_indices_cpu(chunk_indices_chunk64, chunk_counts_chunk64) + _fill_chunk_offsets_cpu(chunk_offsets_chunk64, chunk_counts_chunk64) + _fill_update_chunk_offsets_cpu(update_chunk_offsets_chunk64, chunk_counts_chunk64) + _fill_final_chunk_indices_cpu(final_chunk_indices_chunk64, chunk_counts_chunk64) 
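+    # The large-block and cumsum tables reuse the same (seq_idx, chunk_idx)
+    # layout as the 64-token table, just computed with coarser chunk sizes.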
+ _fill_chunk_indices_cpu(chunk_indices_large_block, chunk_counts_large) + _fill_chunk_indices_cpu(block_indices_cumsum, chunk_counts_cumsum) + + return GDNChunkedPrefillMetadata( + chunk_indices_chunk64=chunk_indices_chunk64.to(builder._ascend_gdn_chunk_meta_device), + chunk_offsets_chunk64=chunk_offsets_chunk64.to(builder._ascend_gdn_chunk_meta_device), + update_chunk_offsets_chunk64=update_chunk_offsets_chunk64.to(builder._ascend_gdn_chunk_meta_device), + final_chunk_indices_chunk64=final_chunk_indices_chunk64.to(builder._ascend_gdn_chunk_meta_device), + chunk_indices_large_block=chunk_indices_large_block.to(builder._ascend_gdn_chunk_meta_device), + block_indices_cumsum=block_indices_cumsum.to(builder._ascend_gdn_chunk_meta_device), + ) + + +def _build_non_spec_chunked_prefill_meta(builder, cu_seqlens_cpu: torch.Tensor) -> GDNChunkedPrefillMetadata: + device = builder._ascend_gdn_chunk_meta_device + if device.type == "cpu": + return _build_non_spec_chunked_prefill_meta_cpu(builder, cu_seqlens_cpu) + + builder._ascend_gdn_chunked_prefill_pool_idx = (builder._ascend_gdn_chunked_prefill_pool_idx + 1) % len( + builder._ascend_gdn_chunked_prefill_pool + ) + slot = builder._ascend_gdn_chunked_prefill_pool[builder._ascend_gdn_chunked_prefill_pool_idx] + chunk_counts_chunk64 = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_chunk_size) + chunk_counts_large = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_large_block_size) + chunk_counts_cumsum = _prepare_chunk_counts_cpu(cu_seqlens_cpu, builder._ascend_gdn_cumsum_block_size) + num_chunk_indices_chunk64 = _fill_chunk_indices_cpu(slot.chunk_indices_chunk64.cpu, chunk_counts_chunk64) + num_chunk_offsets_chunk64 = _fill_chunk_offsets_cpu(slot.chunk_offsets_chunk64.cpu, chunk_counts_chunk64) + num_update_chunk_offsets_chunk64 = _fill_update_chunk_offsets_cpu( + slot.update_chunk_offsets_chunk64.cpu, chunk_counts_chunk64 + ) + num_final_chunk_indices_chunk64 = _fill_final_chunk_indices_cpu( + slot.final_chunk_indices_chunk64.cpu, chunk_counts_chunk64 + ) + num_chunk_indices_large = _fill_chunk_indices_cpu(slot.chunk_indices_large_block.cpu, chunk_counts_large) + num_block_indices_cumsum = _fill_chunk_indices_cpu(slot.block_indices_cumsum.cpu, chunk_counts_cumsum) + + chunk_indices_chunk64 = slot.chunk_indices_chunk64.copy_to_gpu(num_chunk_indices_chunk64) + chunk_offsets_chunk64 = slot.chunk_offsets_chunk64.copy_to_gpu(num_chunk_offsets_chunk64) + update_chunk_offsets_chunk64 = slot.update_chunk_offsets_chunk64.copy_to_gpu(num_update_chunk_offsets_chunk64) + final_chunk_indices_chunk64 = slot.final_chunk_indices_chunk64.copy_to_gpu(num_final_chunk_indices_chunk64) + chunk_indices_large_block = slot.chunk_indices_large_block.copy_to_gpu(num_chunk_indices_large) + block_indices_cumsum = slot.block_indices_cumsum.copy_to_gpu(num_block_indices_cumsum) + return GDNChunkedPrefillMetadata( + chunk_indices_chunk64=chunk_indices_chunk64, + chunk_offsets_chunk64=chunk_offsets_chunk64, + update_chunk_offsets_chunk64=update_chunk_offsets_chunk64, + final_chunk_indices_chunk64=final_chunk_indices_chunk64, + chunk_indices_large_block=chunk_indices_large_block, + block_indices_cumsum=block_indices_cumsum, + _buffer_slot=slot, + ) + + +def _patched_build( + self, + common_prefix_len: int, + common_attn_metadata, + num_accepted_tokens: torch.Tensor | None = None, + num_decode_draft_tokens_cpu: torch.Tensor | None = None, + fast_build: bool = False, +): + attn_metadata = _ORIGINAL_BUILD( + self, + common_prefix_len, + common_attn_metadata, 
+ num_accepted_tokens=num_accepted_tokens, + num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu, + fast_build=fast_build, + ) + attn_metadata.non_spec_chunked_prefill_meta = None + if attn_metadata.num_prefills <= 0: + return attn_metadata + + _ensure_chunk_meta_state(self, common_attn_metadata.query_start_loc.device) + non_spec_query_start_loc_cpu = _build_non_spec_query_start_loc_cpu( + self, + attn_metadata, + common_attn_metadata, + num_decode_draft_tokens_cpu, + ) + assert non_spec_query_start_loc_cpu is not None + attn_metadata.non_spec_chunked_prefill_meta = _build_non_spec_chunked_prefill_meta( + self, non_spec_query_start_loc_cpu + ) + return attn_metadata + + +if not _IS_PATCHED: + gdn_attn.GDNChunkedPrefillMetadata = GDNChunkedPrefillMetadata + gdn_attn.GDNAttentionMetadataBuilder.build = _patched_build + _IS_PATCHED = True diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py index 536e4695..0aa9bce6 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_5.py +++ b/vllm_ascend/patch/worker/patch_qwen3_5.py @@ -178,6 +178,11 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet): if attn_metadata.num_prefills > 0: initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 + non_spec_chunked_prefill_meta = getattr( + attn_metadata, + "non_spec_chunked_prefill_meta", + None, + ) ( core_attn_out_non_spec, last_recurrent_state, @@ -190,6 +195,7 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet): initial_state=initial_state, output_final_state=True, cu_seqlens=non_spec_query_start_loc, + prebuilt_meta=non_spec_chunked_prefill_meta, head_first=False, use_qk_l2norm_in_kernel=True, ) diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py index 29694aed..d458aefc 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_next.py +++ b/vllm_ascend/patch/worker/patch_qwen3_next.py @@ -237,6 +237,11 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet): initial_state = ssm_state[non_spec_state_indices_tensor].transpose(-1, -2).contiguous() initial_state[~has_initial_state, ...] = 0 + non_spec_chunked_prefill_meta = getattr( + attn_metadata, + "non_spec_chunked_prefill_meta", + None, + ) ( core_attn_out_non_spec, last_recurrent_state, @@ -249,6 +254,7 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet): initial_state=initial_state, output_final_state=True, cu_seqlens=non_spec_query_start_loc, + prebuilt_meta=non_spec_chunked_prefill_meta, head_first=False, use_qk_l2norm_in_kernel=True, )