From 9d09488b4a5c64ca52987da6f1c0d159e7fe9dae Mon Sep 17 00:00:00 2001 From: Bai Yongbin <845473182@qq.com> Date: Sat, 28 Feb 2026 21:44:08 +0800 Subject: [PATCH] [Feat] support basic pcp&dcp for qwen3next (#6091) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? This PR implements Context Parallelism (CP) support for the Qwen3-Next model, including PCP (Parallel Context Parallelism) and DCP (Dynamic/Data Context Parallelism). - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/f176443446f659dbab5315e056e605d8984fd976 --------- Signed-off-by: SunnyLee219 <3294305115@qq.com> Signed-off-by: Jingchun Gao Signed-off-by: 白永斌 Signed-off-by: Bai Yongbin <845473182@qq.com> Co-authored-by: SunnyLee219 <3294305115@qq.com> Co-authored-by: Jingchun Gao Co-authored-by: 白永斌 Co-authored-by: Mengqing Cao --- .../4-cards/long_sequence/test_basic.py | 32 +- tests/ut/attention/test_attention_cp.py | 18 +- vllm_ascend/ascend_forward_context.py | 2 + vllm_ascend/attention/attention_v1.py | 16 +- .../context_parallel/attention_cp.py | 119 +++++++- .../attention/context_parallel/common_cp.py | 5 + vllm_ascend/attention/utils.py | 15 + vllm_ascend/ops/fused_moe/prepare_finalize.py | 9 + vllm_ascend/ops/triton/fla/chunk.py | 57 +++- vllm_ascend/ops/triton/fla/chunk_delta_h.py | 4 + .../ops/triton/fla/chunk_delta_hupdate.py | 211 +++++++++++++ vllm_ascend/ops/triton/fla/chunk_o_update.py | 121 ++++++++ vllm_ascend/ops/triton/fla/utils.py | 9 + vllm_ascend/ops/triton/mamba/causal_conv1d.py | 30 +- vllm_ascend/worker/model_runner_v1.py | 58 +++- vllm_ascend/worker/pcp_utils.py | 281 +++++++++++++++--- 16 files changed, 906 insertions(+), 81 deletions(-) create mode 100644 vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py create mode 100644 vllm_ascend/ops/triton/fla/chunk_o_update.py diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py index db8a26fe..fa6e2633 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py @@ -44,16 +44,15 @@ def test_models_pcp_dcp_basic(): runner.model.generate(prompts, sampling_params) model = "vllm-ascend/Qwen3-30B-A3B-W8A8" - with VllmRunner( - model, - enforce_eager=True, - max_model_len=1024, - tensor_parallel_size=2, - prefill_context_parallel_size=2, - decode_context_parallel_size=1, - enable_expert_parallel=True, - block_size=128, - quantization="ascend", + with VllmRunner(model, + enforce_eager=True, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=1, + enable_expert_parallel=True, + block_size=128, + quantization="ascend", ) as runner: runner.model.generate(prompts, sampling_params) @@ -71,6 +70,19 @@ def test_models_pcp_dcp_basic(): ) as runner: runner.model.generate(prompts, sampling_params) + model = "Qwen/Qwen3-Next-80B-A3B-Instruct" + with VllmRunner(model, + enforce_eager=True, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=1, + max_num_batched_tokens=1024, + enable_expert_parallel=True, + gpu_memory_utilization=0.8, + block_size=128) as runner: + runner.model.generate(prompts, sampling_params) + def test_models_pcp_dcp_full_graph(): prompts = [ diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index 877f593b..158d951a 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -81,6 +81,9 @@ class TestAscendAttentionCPImpl(TestBase): [0]) attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = torch.tensor( [0]) + attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx = torch.tensor( + [0, 1]) + attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False output, attn_lse = self.impl._forward_prefill_cp( query, key, value, attn_metadata) @@ -257,12 +260,23 @@ class TestAscendAttentionCPImpl(TestBase): attn_metadata.prefill = MagicMock() attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor( [0, 3, 1, 2, 0, 0, 0, 0]) + attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False + attn_metadata.prefill.pcp_metadata.pcp_padded_tokens_fla = 0 + attn_metadata.prefill.pcp_metadata.pcp_enter_fa_restore_idx = torch.arange( + num_tokens * 3 * self.impl.pcp_size + ) + attn_metadata.prefill.pcp_metadata.pcp_unpad_mask = torch.tensor( + [True, False, True, True, True, True, True, True] + ) + query = torch.rand(num_tokens, num_heads, head_size) key = torch.randn(num_tokens, num_heads, head_size) value = torch.randn(num_tokens, num_heads, head_size) + output = torch.rand(num_tokens, num_heads * head_size) - key, value = self.impl.reshape_and_cache(key, value, kv_cache, - attn_metadata) + query, key, value, output = self.impl.reshape_and_cache( + query, key, value, kv_cache, attn_metadata, output + ) self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size) self.assertEqual(key.shape[1], num_heads) self.assertEqual(key.shape[2], head_size) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index e1384653..b3426451 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -43,6 +43,7 @@ def set_ascend_forward_context( model_instance: torch.nn.Module = None, is_draft_model=False, skip_compiled: bool = False, + max_tokens_across_pcp: int = 0, draft_attn_metadatas=None, ): """A context manager that stores the current forward context, @@ -139,6 +140,7 @@ def set_ascend_forward_context( max_tokens_across_dp = num_tokens forward_context.max_tokens_across_dp = max_tokens_across_dp + forward_context.max_tokens_across_pcp = max_tokens_across_pcp if num_tokens is not None: if num_actual_tokens is None: diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 8b1b5a15..c3cf61b9 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -892,10 +892,12 @@ class AscendAttentionBackendImpl(AttentionImpl): def reshape_and_cache( self, + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata, + output: torch.Tensor, ): if len(kv_cache) > 1: if self.is_kv_producer: @@ -915,7 +917,7 @@ class AscendAttentionBackendImpl(AttentionImpl): ) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() - return key, value + return query, key, value, output def forward_impl( self, @@ -970,12 +972,20 @@ class AscendAttentionBackendImpl(AttentionImpl): num_tokens = query.shape[0] if attn_metadata is None: return output.fill_(0) + output_padded = None if key is not None and value is not None: - key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata) + output_padded = output + query, key, value, output_padded = self.reshape_and_cache( + query, key, value, kv_cache, attn_metadata, output + ) # pooling model branch if attn_metadata.model_runner_type == "pooling": attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output) output[:num_tokens] = attn_output[:num_tokens] return output - output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output) + if output_padded is not None: + attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output_padded) + else: + attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output) + output[:num_tokens] = attn_output[:num_tokens] return output diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 121743f9..c9da487b 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -20,6 +20,7 @@ from typing import ClassVar import numpy as np import torch import torch.distributed as dist +import torch.nn.functional as F import torch_npu from vllm.config import VllmConfig from vllm.distributed import ( @@ -209,7 +210,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): head_attn_nomask_seqlens=head_attn_nomask_seqlens, tail_attn_nomask_seqlens=tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, + pcp_use_hybrid_attn=common_long_seq_metadata.pcp_use_hybrid_attn, + pcp_unpad_mask=common_long_seq_metadata.pcp_unpad_mask, pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx, + pcp_fa_query_idx=common_long_seq_metadata.pcp_fa_query_idx, + pcp_padded_tokens_fla=common_long_seq_metadata.pcp_padded_tokens_fla, + pcp_enter_fa_restore_idx=common_long_seq_metadata.pcp_enter_fa_restore_idx, ) prefill_metadata = AscendMetadataForPrefill( @@ -469,6 +475,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx + if attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn: + fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx + query = torch.index_select(query, 0, fa_query_idx) + q_head = torch.index_select(query, 0, q_head_idx) q_tail = torch.index_select(query, 0, q_tail_idx) k_head_nomask = torch.index_select(key, 0, kv_with_q_head_nomask_idx) if self.pcp_rank > 0 else None @@ -735,14 +745,18 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): def reshape_and_cache( self, + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata, + output: torch.Tensor, ): + num_tokens = query.shape[0] num_decode_tokens = attn_metadata.num_decode_tokens has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 + output_padded = output if len(kv_cache) > 1: if self.is_kv_producer: @@ -762,14 +776,23 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): if has_prefill: if self.pcp_size > 1: - kv = torch.cat([key, value], dim=-1) - num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size - all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0) - assert attn_metadata.prefill is not None - assert attn_metadata.prefill.pcp_metadata is not None - pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx - all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx) - key, value = all_kv.split([self.head_size, self.head_size], dim=-1) + assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None + if not attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn: + kv = torch.cat([key, value], dim=-1) + num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size + all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0) + pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx + all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx) + key, value = all_kv.split([self.head_size, self.head_size], dim=-1) + else: + query, key, value = self._gather_and_restore_pcp_qkv(query, key, value, attn_metadata) + num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded + output_local_padded_tokens_fa = num_actual_tokens_pcp_padded // self.pcp_size - num_tokens + if output_local_padded_tokens_fa > 0: + output_padded = F.pad( + output, pad=(0, 0, 0, 0, 0, output_local_padded_tokens_fa), mode="constant", value=0 + ) + prefill_key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded] prefill_value = value[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded] slot_mapping = attn_metadata.slot_mapping[ @@ -784,7 +807,62 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): ) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() - return key, value + return query, key, value, output_padded + + def _gather_and_restore_pcp_qkv( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + ): + """ + Gathers QKV chunks from all GPUs in the PCP group and restores the original + sequence order for Context Parallelism (CP). + """ + num_tokens = query.shape[0] + num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded + assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None + pcp_padded_tokens_fla = attn_metadata.prefill.pcp_metadata.pcp_padded_tokens_fla + num_tokens_pcp_padded_fla = num_tokens + pcp_padded_tokens_fla + + qkv_fla = torch.cat( + [query.reshape(num_tokens, -1), key.reshape(num_tokens, -1), value.reshape(num_tokens, -1)], + dim=-1, + ) + if pcp_padded_tokens_fla > 0: + qkv_fla = F.pad(qkv_fla, pad=(0, 0, 0, pcp_padded_tokens_fla), mode="constant", value=0) + all_qkv = get_pcp_group().all_gather(qkv_fla[:num_tokens_pcp_padded_fla].contiguous(), dim=0) + + # Restore the original sequence order using pre-computed indices + pcp_enter_fa_restore_idx = ( + attn_metadata.prefill.pcp_metadata.pcp_enter_fa_restore_idx if attn_metadata.prefill.pcp_metadata else None + ) + actual_qkv = torch.index_select(all_qkv, 0, pcp_enter_fa_restore_idx) + qkv_fa_padding_workspace = query.new_empty( + (num_actual_tokens_pcp_padded, (self.num_heads + 2 * self.num_kv_heads) * self.head_size) + ) + + decode_offset = attn_metadata.num_decode_tokens * self.pcp_size + qkv_fa_padding_workspace[:decode_offset] = actual_qkv[:decode_offset] + + pcp_unpad_mask = attn_metadata.prefill.pcp_metadata.pcp_unpad_mask[attn_metadata.num_decodes * self.pcp_size :] + qkv_fa_padding_workspace[decode_offset:][pcp_unpad_mask] = actual_qkv[decode_offset:] + + q, k, v = qkv_fa_padding_workspace.split( + [ + self.num_heads * self.head_size, + self.num_kv_heads * self.head_size, + self.num_kv_heads * self.head_size, + ], + dim=-1, + ) + + return ( + q.reshape(-1, self.num_heads, self.head_size), + k.reshape(-1, self.num_kv_heads, self.head_size), + v.reshape(-1, self.num_kv_heads, self.head_size), + ) def _gather_global_context_output(self, local_context_attn_output): if self.dcp_size > 1: @@ -831,8 +909,15 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens + pcp_use_hybrid_attn = False + if has_prefill: + assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None + pcp_use_hybrid_attn = attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn if has_decode: - decode_query = query[:num_decode_tokens] + if pcp_use_hybrid_attn: + decode_query = query[: num_decode_tokens * self.pcp_size : self.pcp_size].contiguous() + else: + decode_query = query[:num_decode_tokens].contiguous() output_decode = self._forward_decode_pcp_dcp(decode_query, attn_metadata) output[:num_decode_tokens] = output_decode if has_prefill: @@ -849,7 +934,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): # qkv init num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size - prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous() + if pcp_use_hybrid_attn: + prefill_query = query[self.pcp_size * num_decode_tokens :] + else: + prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous() key = key[self.pcp_size * num_decode_tokens :].contiguous() value = value[self.pcp_size * num_decode_tokens :].contiguous() @@ -914,5 +1002,14 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): attn_output_prefill, attn_lse_prefill, context_output, context_lse, prefill_query, attn_metadata ) + if self.pcp_size > 1 and pcp_use_hybrid_attn: + # layer_idx != num_layers - 1 + assert attn_metadata.prefill.pcp_metadata is not None + pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx + attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0) + attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_allgather_restore_idx) + fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0] + output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0) + output[num_decode_tokens : attn_output_prefill.shape[0] + num_decode_tokens] = attn_output_prefill return output diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 8e9517e2..6eb42040 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -25,7 +25,12 @@ class AscendPCPMetadata: head_attn_nomask_seqlens: torch.Tensor = None tail_attn_nomask_seqlens: torch.Tensor = None q_full_idx: torch.Tensor = None + pcp_use_hybrid_attn: bool = False + pcp_unpad_mask: torch.Tensor = None pcp_allgather_restore_idx: list[int] | None = None + pcp_fa_query_idx: torch.Tensor = None + pcp_padded_tokens_fla: int = 0 + pcp_enter_fa_restore_idx: torch.Tensor = None block_table_cp: torch.Tensor = None valid_block_ids: torch.Tensor = None prefill_q_cum_seqlens: torch.Tensor = None diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 8eb91ec5..34244e0f 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -101,6 +101,21 @@ class AscendPrefillContextParallelMetadata: # original max_query_len before pcp split max_query_len_pcp_full: int = 0 + # the following attributes are specifically used in hybrid-attn models. + pcp_use_hybrid_attn: bool = False + + pcp_unpad_mask: torch.Tensor = None + + # to get the right order of query in prefill per rank + pcp_fa_query_idx: torch.Tensor = None + + # restore the full sequence across all pcp ranks + # when entering from linear-attention to attention + pcp_enter_fa_restore_idx: torch.Tensor = None + + # the number of tokens padded in linear-attn per rank + pcp_padded_tokens_fla: int = 0 + @dataclass class AscendCommonAttentionMetadata(CommonAttentionMetadata): diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index cced7ae6..9fe9239e 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -377,6 +377,15 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): router_logits = self.moe_config.dp_group.all_gather(router_logits, 0) if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: + forward_context = get_forward_context() + max_tokens_across_pcp = forward_context.max_tokens_across_pcp + + self.num_tokens_pcp = hidden_states.shape[0] + pad_size = max_tokens_across_pcp - self.num_tokens_pcp + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size)) + hidden_states = get_pcp_group().all_gather( hidden_states, dim=0, diff --git a/vllm_ascend/ops/triton/fla/chunk.py b/vllm_ascend/ops/triton/fla/chunk.py index 58b5cc72..125bc13e 100644 --- a/vllm_ascend/ops/triton/fla/chunk.py +++ b/vllm_ascend/ops/triton/fla/chunk.py @@ -12,15 +12,19 @@ import warnings import torch from einops import rearrange +from vllm.distributed import get_pcp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_delta_hupdate import chunk_gated_delta_rule_fwd_hupdate from .chunk_o import chunk_fwd_o +from .chunk_o_update import chunk_fwd_o_update from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd from .cumsum import chunk_local_cumsum from .l2norm import l2norm_fwd from .solve_tril import solve_tril -from .utils import input_guard +from .utils import input_guard, prepare_final_chunk_indices from .wy_fast import recompute_w_u_fwd @@ -35,7 +39,15 @@ def chunk_gated_delta_rule_fwd( output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, ): - g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + forward_context = get_forward_context() + num_decodes = 0 + attn_metadata = forward_context.attn_metadata + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values()), None) + if attn_metadata is not None: + num_decodes = attn_metadata.num_decodes + chunk_size = 64 + g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) @@ -56,6 +68,45 @@ def chunk_gated_delta_rule_fwd( output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) + + if get_pcp_group().world_size > 1: + h_update = chunk_gated_delta_rule_fwd_hupdate( + k=k, + w=w, + u=u, + g=g, + cu_seqlens=cu_seqlens, + num_decodes=num_decodes, + ) + all_final_state = get_pcp_group().all_gather(final_state.unsqueeze(0), 0) + final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size) + final_h_update = h_update[:, final_chunk_indices, :, :, :] + all_final_h_update = get_pcp_group().all_gather(final_h_update, 0) + + updated_state = final_state.new_empty(get_pcp_group().world_size, *final_state.shape) + updated_state[0, ...] = all_final_state[0] + for i in range(1, get_pcp_group().world_size): + updated_final_state = all_final_state[i] + torch.matmul( + all_final_h_update[i, ...], updated_state[i - 1, ...] + ) + updated_state[i, ...] = updated_final_state + + final_state = updated_state[-1, ...] + + if get_pcp_group().rank_in_group == 0: + updated_h_state = torch.zeros_like(final_state) + else: + updated_h_state = updated_state[get_pcp_group().rank_in_group - 1, ...] + + h = chunk_fwd_o_update( + q=q, + v=v_new, + h=h, + h_update=h_update, + updated_h_state=updated_h_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( q=q, k=k, @@ -65,6 +116,7 @@ def chunk_gated_delta_rule_fwd( scale=scale, cu_seqlens=cu_seqlens, ) + if SUPPRESS_LEVEL < 3: return g, o, A, final_state, None, None, None elif SUPPRESS_LEVEL >= 3: @@ -90,7 +142,6 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function): if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) k = l2norm_fwd(k) - g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( q=q, k=k, diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_h.py b/vllm_ascend/ops/triton/fla/chunk_delta_h.py index d08fe9aa..85eab41c 100644 --- a/vllm_ascend/ops/triton/fla/chunk_delta_h.py +++ b/vllm_ascend/ops/triton/fla/chunk_delta_h.py @@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ht, cu_seqlens, chunk_offsets, + h_update, T, H: tl.constexpr, Hg: tl.constexpr, @@ -72,6 +73,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( b_h1_bv1 = tl.zeros([128, 64], dtype=tl.float32) b_h1_bv2 = tl.zeros([128, 64], dtype=tl.float32) + # create b_hupd_bv1 and b_hupd_bv2 v_start1 = 0 v_start2 = 64 @@ -204,6 +206,7 @@ def chunk_gated_delta_rule_fwd_h( assert K <= 256, "current kernel does not support head dimension larger than 256." h = k.new_empty(B, NT, H, K, V) + h_update = k.new_empty(B, NT, H, K, K) final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None v_new = torch.empty_like(u) if save_new_value else None @@ -223,6 +226,7 @@ def chunk_gated_delta_rule_fwd_h( ht=final_state, cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, + h_update=h_update, T=T, H=H, Hg=Hg, diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py b/vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py new file mode 100644 index 00000000..7ab1ef78 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_delta_hupdate.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_update_chunk_offsets, safe_exp + +_CONDITIONS = ("seq7168",) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_hupdate_blockdim64( + k, + w, + g, + cu_seqlens, + chunk_offsets, + h_update, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_nh = tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + T_max = 1 * T + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + stride_k = Hg * K + stride_w = H * K + + # create b_hupd_bv1 and b_hupd_bv2 + off_hupd_1_top = tl.arange(0, 64)[:, None] + off_hupd_2_top = tl.arange(0, 64)[None, :] + + # main recurrence + for i_t in range(NT): + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos + i_h * T_max + last_idx) + + offs_t = i_t * BT + tl.arange(0, BT) + mask_t = offs_t < T + g_ptr = g + bos + i_h * T_max + b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0) + + b_g = safe_exp(b_g_last - b_g) + b_g_last = tl.exp(b_g_last) + + offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None] + w_base = w + bos * H * K + i_h * K + # get column-sliced w [BT, 64] + offs_w_upd1 = tl.arange(0, 64)[None, :] + mask_w_upd1 = (offs_t_wv < T) & (offs_w_upd1 < K) + ptr_w_upd1 = w_base + offs_t_wv * stride_w + offs_w_upd1 * 1 + b_w_upd1 = tl.load(ptr_w_upd1, mask=mask_w_upd1, other=0.0).to(tl.float32) + + offs_w_upd2 = 64 + tl.arange(0, 64)[None, :] + mask_w_upd2 = (offs_t_wv < T) & (offs_w_upd2 < K) + ptr_w_upd2 = w_base + offs_t_wv * stride_w + offs_w_upd2 * 1 + b_w_upd2 = tl.load(ptr_w_upd2, mask=mask_w_upd2, other=0.0).to(tl.float32) + + k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K + # get row-sliced k [64, T] + p_k_upd1 = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k_upd1 = tl.load(p_k_upd1, boundary_check=(0, 1)) + p_k_upd2 = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k_upd2 = tl.load(p_k_upd2, boundary_check=(0, 1)) + + if USE_G: + b_w_upd1 = b_w_upd1 * b_g[:, None] + b_w_upd2 = b_w_upd2 * b_g[:, None] + + # compute [64, BT] @ [BT, 64] + b_hupd_local_11 = (off_hupd_1_top == off_hupd_2_top).to(tl.float32) + b_hupd_local_22 = (off_hupd_1_top == off_hupd_2_top).to(tl.float32) + + # fp32 + if USE_G: + b_hupd_local_11 = b_hupd_local_11 * b_g_last + b_hupd_local_22 = b_hupd_local_22 * b_g_last + + b_hupd_local_11 -= tl.dot(b_k_upd1, b_w_upd1.to(b_k_upd1.dtype)) + b_hupd_local_22 -= tl.dot(b_k_upd2, b_w_upd2.to(b_k_upd2.dtype)) + b_hupd_local_12 = -tl.dot(b_k_upd1, b_w_upd2.to(b_k_upd1.dtype)).to(tl.float32) + b_hupd_local_21 = -tl.dot(b_k_upd2, b_w_upd1.to(b_k_upd2.dtype)).to(tl.float32) + + hupd_base = h_update + (boh + i_t + i_n) * H * K * K + i_h * K * K + p_hupd_11 = tl.make_block_ptr(hupd_base, (K, K), (K, 1), (0, 0), (64, 64), (1, 0)) + b_hupd_11 = tl.load(p_hupd_11, boundary_check=(1, 0)) + p_hupd_21 = tl.make_block_ptr(hupd_base, (K, K), (K, 1), (64, 0), (64, 64), (1, 0)) + b_hupd_21 = tl.load(p_hupd_21, boundary_check=(1, 0)) + p_hupd_12 = tl.make_block_ptr(hupd_base, (K, K), (K, 1), (0, 64), (64, 64), (1, 0)) + b_hupd_12 = tl.load(p_hupd_12, boundary_check=(1, 0)) + p_hupd_22 = tl.make_block_ptr(hupd_base, (K, K), (K, 1), (64, 64), (64, 64), (1, 0)) + b_hupd_22 = tl.load(p_hupd_22, boundary_check=(1, 0)) + + b_hupd11_new = tl.dot(b_hupd_local_11.to(b_hupd_11.dtype), b_hupd_11).to(tl.float32) + b_hupd11_new += tl.dot(b_hupd_local_12.to(b_hupd_21.dtype), b_hupd_21) + + b_hupd21_new = tl.dot(b_hupd_local_21.to(b_hupd_11.dtype), b_hupd_11).to(tl.float32) + b_hupd21_new += tl.dot(b_hupd_local_22.to(b_hupd_21.dtype), b_hupd_21) + + b_hupd12_new = tl.dot(b_hupd_local_11.to(b_hupd_12.dtype), b_hupd_12).to(tl.float32) + b_hupd12_new += tl.dot(b_hupd_local_12.to(b_hupd_22.dtype), b_hupd_22) + + b_hupd22_new = tl.dot(b_hupd_local_21.to(b_hupd_12.dtype), b_hupd_12).to(tl.float32) + b_hupd22_new += tl.dot(b_hupd_local_22.to(b_hupd_22.dtype), b_hupd_22) + + hupd_next = h_update + (boh + i_t + i_n + 1) * H * K * K + i_h * K * K + p_hupd_11 = tl.make_block_ptr(hupd_next, (K, K), (K, 1), (0, 0), (64, 64), (1, 0)) + tl.store(p_hupd_11, b_hupd11_new.to(p_hupd_11.dtype.element_ty), boundary_check=(0, 1)) + + p_hupd_21 = tl.make_block_ptr(hupd_next, (K, K), (K, 1), (64, 0), (64, 64), (1, 0)) + tl.store(p_hupd_21, b_hupd21_new.to(p_hupd_21.dtype.element_ty), boundary_check=(0, 1)) + + p_hupd_12 = tl.make_block_ptr(hupd_next, (K, K), (K, 1), (0, 64), (64, 64), (1, 0)) + tl.store(p_hupd_12, b_hupd12_new.to(p_hupd_12.dtype.element_ty), boundary_check=(0, 1)) + + p_hupd_22 = tl.make_block_ptr(hupd_next, (K, K), (K, 1), (64, 64), (64, 64), (1, 0)) + tl.store(p_hupd_22, b_hupd22_new.to(p_hupd_22.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_hupdate( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + cu_seqlens: torch.LongTensor | None = None, + num_decodes: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K, _ = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h_update = k.new_empty(B, NT + N, H, K, K, dtype=torch.float32) + update_indices = prepare_update_chunk_offsets(cu_seqlens, BT)[:-1] + h_update[:, update_indices, :, :, :] = torch.eye(K, dtype=h_update.dtype, device=h_update.device) + + g = g.transpose(1, 2).contiguous() + + def grid(meta): + return (1, N * H) + + chunk_gated_delta_rule_fwd_kernel_hupdate_blockdim64[grid]( + k=k, + w=w, + g=g, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + h_update=h_update, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + num_warps=4, + num_stages=2, + ) + h_update[:, : num_decodes * 2, :, :, :] = torch.zeros((K, K), dtype=h_update.dtype, device=h_update.device) + return h_update diff --git a/vllm_ascend/ops/triton/fla/chunk_o_update.py b/vllm_ascend/ops/triton/fla/chunk_o_update.py new file mode 100644 index 00000000..9479bd75 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_o_update.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +# mypy: ignore-errors + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_offsets + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o_update( + h, + h_update, + updated_h_state, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H # splitting by the head of the req + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int64) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # offset calculation + updated_h_state += (i_n * H + i_h).to(tl.int64) * K * V + + for i_t in range(NT): + i_tg = boh + i_t + h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V + hupd_base = h_update + ((i_tg + i_n) * H + i_h).to(tl.int64) * K * K + + for i_k in range(tl.cdiv(K, BK)): + p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_hupd = tl.make_block_ptr(hupd_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BK), (1, 0)) + p_updated_h_state = tl.make_block_ptr( + updated_h_state, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0) + ) + + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BK] + b_hupd = tl.load(p_hupd, boundary_check=(0, 1)) + # [BK, BV] + b_updated_h_state = tl.load(p_updated_h_state, boundary_check=(0, 1)) + + b_h += tl.dot(b_hupd.to(tl.bfloat16), b_updated_h_state.to(tl.bfloat16)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o_update( + q: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + h_update: torch.Tensor, + updated_h_state: torch.Tensor, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = chunk_size + + if cu_seqlens is None: + N, chunk_offsets = B, None + else: + N, chunk_offsets = ( + len(cu_seqlens) - 1, + prepare_chunk_offsets(cu_seqlens, BT), + ) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_fwd_kernel_o_update[grid]( + h=h, + h_update=h_update, + updated_h_state=updated_h_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=128, + BV=128, + num_warps=4, + num_stages=2, + ) + return h diff --git a/vllm_ascend/ops/triton/fla/utils.py b/vllm_ascend/ops/triton/fla/utils.py index fa23c2af..680bb9ca 100644 --- a/vllm_ascend/ops/triton/fla/utils.py +++ b/vllm_ascend/ops/triton/fla/utils.py @@ -24,10 +24,19 @@ def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torc return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) +def prepare_final_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + 1 + return torch.cumsum(indices, 0) - 1 + + def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) +def prepare_update_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + 1]).cumsum(-1) + + def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: """ A decorator to make sure all input tensors are contiguous and set the device based on input tensors. diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index 4080d6c3..4e3680ef 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -13,6 +13,8 @@ import torch import torch.nn.functional as F import triton import triton.language as tl +from vllm.distributed import get_pcp_group +from vllm.forward_context import get_forward_context from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore @@ -96,6 +98,14 @@ def causal_conv1d_fn( indices 0 and 3 out: (batch, dim, seqlen) """ + forward_context = get_forward_context() + num_decodes = 0 + attn_metadata = forward_context.attn_metadata + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values()), None) + if attn_metadata is not None: + num_decodes = attn_metadata.num_decodes + if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") if x.stride(-1) != 1: @@ -108,6 +118,13 @@ def causal_conv1d_fn( seqlens = seqlens.tolist() splits = torch.split(x, seqlens, dim=-1) width = weight.shape[1] + last_width_prefill_x = extract_last_width(x, query_start_loc[num_decodes:], conv_states.shape[-1]) + + if get_pcp_group().world_size > 1: + all_last_width_prefill_x = get_pcp_group().all_gather(last_width_prefill_x.unsqueeze(0).contiguous(), 0) + pcp_rank = get_pcp_group().rank_in_group + if pcp_rank > 0: + conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[pcp_rank - 1, ...] for i in range(len(seqlens)): x_s = splits[i] @@ -121,14 +138,25 @@ def causal_conv1d_fn( activation=activation, return_final_states=True, final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0), - initial_states=conv_states[cache_indices[i]][..., : (width - 1)] if has_initial_state[i] else None, + initial_states=conv_states[cache_indices[i]][..., : (width - 1)], ) ) + + if get_pcp_group().world_size > 1: + conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[-1, ...] out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) out_ref_tensor = torch.cat(out_ref, dim=0) return out_ref_tensor +def extract_last_width(x, start_loc, width): + end_loc = start_loc[1:] + offsets = torch.arange(width, device=x.device) + indices = end_loc.unsqueeze(1) - width + offsets.unsqueeze(0) # (num_seqs, width) + + return x[:, indices].permute(1, 0, 2) + + @triton.jit def _causal_conv1d_update_kernel_npu_tiled( # Pointers diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ee5dd81b..ef80131f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -383,6 +383,7 @@ class NPUModelRunner(GPUModelRunner): self.intermediate_tensors: IntermediateTensors | None = None self.reorder_batch_threshold: int | None = None self.long_seq_metadata = None + self.query_lens: torch.Tensor | None = None self.cpu_slot_mapping = None @property @@ -543,10 +544,12 @@ class NPUModelRunner(GPUModelRunner): self, scheduler_output: "SchedulerOutput", num_scheduled_tokens: np.ndarray, - ) -> tuple[torch.Tensor, SpecDecodeMetadata | None]: + ) -> tuple[torch.Tensor, SpecDecodeMetadata | None, int]: """ :return: tuple[ - logits_indices, spec_decode_metadata, + logits_indices, + spec_decode_metadata, + total_num_scheduled_tokens, ] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -610,11 +613,10 @@ class NPUModelRunner(GPUModelRunner): if self.pcp_size > 1: num_scheduled_tokens[:num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp( - num_scheduled_tokens[:num_reqs], - self.arange_np, + num_scheduled_tokens[:num_reqs], self.arange_np ) # Re-update after PCP split sequences. - total_num_scheduled_tokens = sum(num_scheduled_tokens) + total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) positions_np = self.positions.np[:total_num_scheduled_tokens] @@ -623,7 +625,11 @@ class NPUModelRunner(GPUModelRunner): position_pcp[:total_num_scheduled_tokens], out=positions_np, ) - self.query_lens = torch.from_numpy(num_scheduled_tokens) + if self.pcp_size > 1 and self.pcp_manager.pcp_use_hybrid_attn: + assert self.pcp_manager.num_scheduled_tokens_padded is not None + self.query_lens = torch.from_numpy(self.pcp_manager.num_scheduled_tokens_padded) + else: + self.query_lens = torch.from_numpy(num_scheduled_tokens) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -702,6 +708,8 @@ class NPUModelRunner(GPUModelRunner): self.seq_lens.np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens self.seq_lens.copy_to_gpu() + # Fill unused with -1. Needed for reshape_and_cache in attention_cp + self.query_start_loc.gpu[num_reqs + 1 :].fill_(-1) self.seq_lens.gpu[num_reqs:].fill_(0) # Copy the tensors to the NPU. @@ -732,6 +740,7 @@ class NPUModelRunner(GPUModelRunner): num_tokens_np = np.array(num_tokens, dtype=np.int32) base_num_reqs = self.input_batch.num_reqs num_reqs = base_num_reqs + tokens_original = None if self.pcp_size > 1: # while pcp > 1, we need the original num_scheduled_tokens before split # to calculate discard_requests_mask @@ -758,7 +767,7 @@ class NPUModelRunner(GPUModelRunner): num_draft_tokens = None num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) if self.use_cp: - logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens) + logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens, num_reqs, tokens_original) logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True) else: logits_indices = self.query_start_loc.gpu[1 : num_reqs + 1] - 1 @@ -807,7 +816,11 @@ class NPUModelRunner(GPUModelRunner): max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])) - return logits_indices, spec_decode_metadata + return ( + logits_indices, + spec_decode_metadata, + total_num_scheduled_tokens, + ) def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0): @@ -1152,6 +1165,7 @@ class NPUModelRunner(GPUModelRunner): ( logits_indices, spec_decode_metadata, + total_num_scheduled_tokens, ) = self._prepare_inputs( scheduler_output, num_scheduled_tokens_np, @@ -1220,7 +1234,9 @@ class NPUModelRunner(GPUModelRunner): num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs) (attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata( - num_tokens=num_tokens_unpadded, + num_tokens=num_tokens_unpadded + if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn) + else total_num_scheduled_tokens, num_tokens_padded=num_tokens_padded, num_reqs=num_reqs, num_reqs_padded=num_reqs_padded, @@ -1240,7 +1256,13 @@ class NPUModelRunner(GPUModelRunner): intermediate_tensors, model_kwargs, ec_connector_output, - ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors) + ) = self._preprocess( + scheduler_output, + num_tokens_padded + if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn) + else total_num_scheduled_tokens, + intermediate_tensors, + ) if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() @@ -1287,6 +1309,7 @@ class NPUModelRunner(GPUModelRunner): batch_descriptor=batch_desc, num_actual_tokens=scheduler_output.total_num_scheduled_tokens, model_instance=self.model, + max_tokens_across_pcp=0 if self.pcp_size == 1 else self.pcp_manager.max_num_tokens_across_pcp, skip_compiled=has_encoder_input, ), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, @@ -1922,11 +1945,16 @@ class NPUModelRunner(GPUModelRunner): def _get_block_table_and_slot_mapping(kv_cache_gid: int): assert num_reqs_padded is not None and num_tokens_padded is not None kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec - maybe_pcp_full_tokens = ( - num_tokens_padded - if self.pcp_size == 1 - else num_tokens * self.pcp_size - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs]) - ) + if self.pcp_size > 1: + total_num_pcp_pads = sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs]) + if self.pcp_manager.pcp_use_hybrid_attn: + num_scheduled_tokens_padded = self.pcp_manager.num_scheduled_tokens_padded + assert num_scheduled_tokens_padded is not None + maybe_pcp_full_tokens = sum(num_scheduled_tokens_padded) * self.pcp_size - total_num_pcp_pads + else: + maybe_pcp_full_tokens = num_tokens * self.pcp_size - total_num_pcp_pads + else: + maybe_pcp_full_tokens = num_tokens_padded if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): blk_table_tensor = torch.zeros( (num_reqs_padded, 1), diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 3528970b..61fdf0c3 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING import numpy as np import torch +import torch.nn.functional as F from vllm.config import VllmConfig from vllm.v1.utils import CpuGpuBuffer @@ -110,6 +111,20 @@ class PCPManager: self.query_lens_pcp_full = CpuGpuBuffer( self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory ) + self.pcp_fa_query_idx = torch.zeros( + self.max_num_tokens + 2 * self.max_num_reqs, dtype=torch.int32, device=self.device + ) + self.pcp_enter_fa_restore_idx = torch.zeros( + self.max_num_tokens + 2 * self.pcp_world_size * self.max_num_reqs, dtype=torch.int32, device=self.device + ) + self.pcp_use_hybrid_attn = self.vllm_config.model_config.hf_config.model_type == "qwen3_next" + + self.pcp_pads_logits_hybrid_attn = torch.ones(self.max_num_reqs, dtype=torch.int32) * (self.pcp_world_size - 1) + self.pcp_padded_tokens_fla = 0 + self.pcp_padded_tokens_length = 0 + self.num_scheduled_tokens_padded = None + self.max_num_tokens_across_pcp = 0 + self.pcp_tokens_padded = None def _get_cumsum_and_arange( self, @@ -184,9 +199,10 @@ class PCPManager: Tuple (pcp_tokens, pcp_positions): - pcp_tokens: number of tokens per request that this PCP rank will actually process (after splitting / replication). + For hybrid-attention model: number of unpadded tokens + per requests - pcp_positions: flattened positions for those tokens on this rank, used to build the positions buffer for the model. - Example: >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp. >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) @@ -219,9 +235,10 @@ class PCPManager: # cu_padded_tokens: cumulative sum of padded token counts, # pcp_padded_arange: per-request arange flattened for padded tokens. cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np) + self.pcp_padded_tokens_length = pcp_padded_arange.shape[0] # Build the mask that marks which positions in the padded allgather buffer # correspond to real (unpadded) tokens. - self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = pcp_padded_arange < np.repeat( + self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length] = pcp_padded_arange < np.repeat( num_scheduled_tokens, num_padded_scheduled_tokens ) unpad_mask_decode = self.pcp_unpad_mask_cpu[: self.num_decode_tokens * self.pcp_world_size] @@ -272,6 +289,9 @@ class PCPManager: return positions positions = get_current_rank_positions(0, self.pcp_world_rank) + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + # Decode tokens are duplicated only after AG. But their positions are # same without prefill context parallel. if self.num_decode_reqs > 0: @@ -279,35 +299,192 @@ class PCPManager: num_scheduled_tokens[: self.num_decode_reqs], arange_np )[1] - # Build the restore index used after allgather. - padded_pos_start_loc = np.roll(cu_padded_tokens, 1) - padded_pos_start_loc[0] = 0 - all_positions_lst = [ - get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size) - ] - all_positions = np.concatenate(all_positions_lst) - self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort() - self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + if self.pcp_use_hybrid_attn: + max_scheduled_prefill_tokens = 0 + self.pcp_padded_tokens_fla = 0 + if self.num_decode_reqs > 0: + num_padded_scheduled_tokens[: self.num_decode_reqs] = ( + num_padded_scheduled_tokens[: self.num_decode_reqs] // self.pcp_world_size + ) + self.total_pcp_padding_tokens_fla = 0 + # have prefills + if self.num_reqs - self.num_decode_reqs > 0: + prefill_tokens_tensor = torch.Tensor(num_scheduled_tokens[self.num_decode_tokens :]) + # [num_prefill_reqs, pcp_world_size, 1] [[3,2]] [[2,2,2,1],[2,1,1,1]] + num_prefill_tokens_allranks = ( + self._get_cp_local_seq_lens(prefill_tokens_tensor, self.pcp_world_size, 1, 1).long().numpy() + ) + # [3] [2] | [2,2] [2,1] [2,1] [1,1] + num_prefill_scheduled_tokens_linear = num_prefill_tokens_allranks[:, self.pcp_world_rank, 0] + num_padded_scheduled_tokens[self.num_decode_reqs :] = num_prefill_scheduled_tokens_linear + # [[3,5]] | [[0,0,0,0,0],[0,0,0,0,0]] + num_prefill_tokens_start_loc = np.zeros( + (self.num_reqs - self.num_decode_reqs, self.pcp_world_size + 1), dtype=np.int64 + ) + # [[0,3,5]] | [[0,2,4,6,7],[0,2,3,4,5]] + num_prefill_tokens_start_loc[:, 1:] = np.cumsum(num_prefill_tokens_allranks[..., 0], axis=-1) + # [0] [3] | [0,0] [2,2] [4,3] [6,4] [7,5] + num_prefill_tokens_cu_ranks = num_prefill_tokens_start_loc[:, self.pcp_world_rank] + # [0,1,2] [0,1] | [0,1,0,1] [0,1,0] [0,1,0] [0,0] + # -> [0,1,2] [3,4] | [0,1,0,1] [2,3,2] [4,5,3] [6,4] + _, positions_linear = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np) + positions_linear[self.num_decode_reqs :] = positions_linear[self.num_decode_reqs :] + np.repeat( + num_prefill_tokens_cu_ranks, num_prefill_scheduled_tokens_linear + ) - self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs] - self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum() - return ( - pcp_tokens[: self.num_reqs], - positions, - ) + max_scheduled_prefill_tokens = num_prefill_tokens_allranks[:, 0, 0].sum() + num_prefill_tokens = num_scheduled_tokens[self.num_decode_reqs :].sum() + self.total_pcp_padding_tokens_fla = ( + max_scheduled_prefill_tokens * self.pcp_world_size - num_prefill_tokens + ) + self.pcp_padded_tokens_fla += max_scheduled_prefill_tokens - num_prefill_scheduled_tokens_linear.sum() - def get_logits_indices(self, cu_num_tokens: np.ndarray): - return torch.from_numpy(cu_num_tokens) * self.pcp_world_size - self.num_pcp_pads_cpu_tensor[: self.num_reqs] - 1 + max_scheduled_tokens = max_scheduled_prefill_tokens + self.num_decode_tokens + enter_fa_prefill_restore_idx = None + if self.num_reqs - self.num_decode_reqs > 0: + # prefill reorder idx + # [[3,2]] [[2,2,2,1],[2,2,1,1],[1,1,1,1]] + num_prefill_tokens_allranks = num_prefill_tokens_allranks[..., 0] + # [0,1,2,0,1] [0,1,0,1,0,1,0,|0,1,0,1,0,0] + _, prefill_arange_allranks = self._get_cumsum_and_arange( + num_prefill_tokens_allranks.flatten(), arange_np + ) + # [0,1] [0,1,2,3,0,1,2,3] + _, prefill_rank_offset = self._get_cumsum_and_arange( + np.ones(self.num_reqs - self.num_decode_reqs, dtype=np.int64) * self.pcp_world_size, arange_np + ) + # [0,0,0,3,3] [0,M,2M,3M,0,M,2M,3M] -> [0,0,M,M,2M,2M,3M,0,0,M,M,2M,3M] + D + prefill_all_offset = ( + np.repeat(prefill_rank_offset * max_scheduled_tokens, num_prefill_tokens_allranks.flatten()) + + self.num_decode_tokens + ) + + # [0,0,0,0,|2,2,2,1,|4,4,3,2] -> [0,0,0,0,0,0,0,|2,2,2,2,2,1,|4,4,3,2] + # [[0,0]] -> [0,0,0,0,0] + prefill_local_start_local = np.zeros_like(num_prefill_tokens_allranks) + prefill_local_start_local[1:, :] = np.cumsum(num_prefill_tokens_allranks, axis=0)[:-1, :] + prefill_local_offset = np.repeat( + prefill_local_start_local.flatten(), num_prefill_tokens_allranks.flatten() + ) + prefill_all_offset = np.add(prefill_all_offset, prefill_local_offset) + # [0,1,2,3,4] [0,1,M,M+1,2M,2M+1,3M,0,1,M,M+1,2M,3M] + enter_fa_prefill_restore_idx = np.add(prefill_all_offset, prefill_arange_allranks) + else: + _, positions_linear = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np) + + # decode reorder idx + enter_fa_decode_restore_idx = None + if self.num_decode_reqs > 0: + # [0,1,2], [4,4,4] -> [0,0,0,0,1,1,1,1,2,2,2,2] + num_decode_pcp_size = np.ones(self.num_decode_reqs, dtype=np.int64) * self.pcp_world_size + decode_reqs_offset = np.repeat(np.arange(self.num_decode_reqs, dtype=np.int64), num_decode_pcp_size) + decode_ranks_offset = ( + self._get_cumsum_and_arange(num_decode_pcp_size, arange_np)[1] * max_scheduled_tokens + ) + enter_fa_decode_restore_idx = np.add(decode_reqs_offset, decode_ranks_offset) + + if enter_fa_decode_restore_idx is not None and enter_fa_prefill_restore_idx is not None: + pcp_enter_fa_restore_idx = torch.from_numpy( + np.concatenate([enter_fa_decode_restore_idx, enter_fa_prefill_restore_idx]) + ) + elif enter_fa_decode_restore_idx is not None: + pcp_enter_fa_restore_idx = torch.from_numpy(enter_fa_decode_restore_idx) + + elif enter_fa_prefill_restore_idx is not None: + pcp_enter_fa_restore_idx = torch.from_numpy(enter_fa_prefill_restore_idx) + self.pcp_enter_fa_restore_idx[: pcp_enter_fa_restore_idx.shape[0]].copy_( + pcp_enter_fa_restore_idx.long(), non_blocking=True + ) + + if self.num_reqs > self.num_decode_reqs: + all_positions_prefill = [ + get_current_rank_positions(padded_pos_start_loc, rank_i)[self.num_decode_tokens :] + - self.num_decode_tokens * self.pcp_world_size + for rank_i in range(self.pcp_world_size) + ] + all_positions_prefill_tensor = torch.from_numpy(np.concatenate(all_positions_prefill)) + all_enter_fla_restore_idx = all_positions_prefill_tensor.float().argsort() + unpad_mask_prefill = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length][ + self.num_decode_reqs * self.pcp_world_size : + ] + # [0] | [0,7] + ori_tokens_start_loc = np.roll(np.cumsum(num_scheduled_tokens[self.num_decode_tokens :]), 1) + ori_tokens_start_loc[0] = 0 + # [0,1,2] [3,4] | [0,1,7,8] [2,3,9] [4,5,10] [6,11] + enter_fla_scatter_idx = positions_linear[self.num_decode_reqs :] + np.repeat( + ori_tokens_start_loc, num_prefill_scheduled_tokens_linear + ) + enter_fla_restore_idx = torch.index_select( + all_enter_fla_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(enter_fla_scatter_idx) + ) + self.pcp_allgather_restore_idx.gpu[: enter_fla_restore_idx.shape[0]].copy_( + enter_fla_restore_idx.long(), non_blocking=True + ) + + positions_prefill = all_positions_prefill[self.pcp_world_rank] + pcp_fa_query_idx_tensor = torch.from_numpy(positions_prefill) + self.pcp_fa_query_idx[: pcp_fa_query_idx_tensor.shape[0]].copy_( + pcp_fa_query_idx_tensor.long(), non_blocking=True + ) + self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs] + self.total_num_sampled_tokens_pcp = num_scheduled_tokens[: self.num_reqs].sum() + self.max_num_tokens_across_pcp = max_scheduled_tokens + self.pcp_tokens_padded = pcp_tokens[: self.num_reqs] + self.num_scheduled_tokens_padded = np.array(self.pcp_tokens_padded, dtype=np.int32) + return num_padded_scheduled_tokens, positions_linear + else: + # Build the restore index used after allgather. + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort() + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + + self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs] + self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum() + return pcp_tokens[: self.num_reqs], positions + + def get_logits_indices( + self, + cu_num_tokens: np.ndarray, + num_reqs: int, + tokens_original: list[int] | None = None, + ): + if not self.pcp_use_hybrid_attn or tokens_original is None: + logits_indices = ( + torch.from_numpy(cu_num_tokens) * self.pcp_world_size + - self.num_pcp_pads_cpu_tensor[: self.num_reqs] + - 1 + ) + else: + tokens_original_tensor = torch.tensor(tokens_original, dtype=torch.int32) + num_prefill_reqs = (tokens_original_tensor > self.decode_threshold).sum().item() + num_decode_reqs = num_reqs - num_prefill_reqs + decode_pads = self.pcp_pads_logits_hybrid_attn[:num_decode_reqs] + pad_len = tokens_original_tensor.shape[0] - num_decode_reqs + tokens_logits = tokens_original_tensor + F.pad(decode_pads, (0, pad_len), value=0) + logits_indices = torch.cumsum(tokens_logits, dim=0) - 1 + return logits_indices def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int, slot_mapping: torch.Tensor): # After pcp allgather and restore, there are padded tokens in kv, # so we need pad slotmapping for alignment. - pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: num_tokens_padded * self.pcp_world_size] - + if self.pcp_use_hybrid_attn: + assert self.num_scheduled_tokens_padded is not None + num_tokens = self.num_scheduled_tokens_padded.sum() + pcp_padded_slot_mapping = ( + self.pcp_padded_slot_mapping[: num_tokens_padded * self.pcp_world_size] + if not self.pcp_use_hybrid_attn + else self.pcp_padded_slot_mapping[: num_tokens * self.pcp_world_size] + ) cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[: num_tokens * self.pcp_world_size] pcp_padded_slot_mapping.fill_(-1) pcp_padded_slot_mapping[: num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping - return pcp_padded_slot_mapping + if self.pcp_use_hybrid_attn: + return pcp_padded_slot_mapping.clone() + else: + return pcp_padded_slot_mapping def get_restore_hidden_states( self, @@ -317,16 +494,25 @@ class PCPManager: # ignores the padding from CUDA Graph. from vllm.distributed.parallel_state import get_pcp_group - hidden_states = get_pcp_group().all_gather( - hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_world_size], - 0, - ) - restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]] - return torch.index_select( - hidden_states, - 0, - restore_idx, - ) + if not self.pcp_use_hybrid_attn: + hidden_states = get_pcp_group().all_gather( + hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_world_size], + 0, + ) + restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]] + return torch.index_select( + hidden_states, + 0, + restore_idx, + ) + else: + if self.pcp_padded_tokens_fla > 0: + hidden_states = F.pad( + hidden_states, pad=(0, 0, 0, self.pcp_padded_tokens_fla), mode="constant", value=0 + ) + hidden_states = get_pcp_group().all_gather(hidden_states.contiguous(), dim=0) + restore_idx = self.pcp_enter_fa_restore_idx[: hidden_states.shape[0] - self.total_pcp_padding_tokens_fla] + return torch.index_select(hidden_states, 0, restore_idx) def generate_pcp_mtp_input( self, @@ -528,6 +714,15 @@ class PCPManager: ): from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata + if self.pcp_world_size > 1 and self.pcp_use_hybrid_attn: + assert self.num_scheduled_tokens_padded is not None + total_num_scheduled_tokens = self.num_scheduled_tokens_padded.sum() + query_lens_new = ( + self.query_lens_pcp_full.cpu[:num_reqs] + if self.pcp_world_size > 1 and self.speculative_config + else query_lens + ) + num_decodes = (query_lens_new <= self.decode_threshold).sum().item() num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded long_seq_metadata = None @@ -599,10 +794,13 @@ class PCPManager: if num_reqs_padded > num_reqs: pad_size = num_reqs_padded - num_reqs ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) - + pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length] long_seq_metadata = AscendPrefillContextParallelMetadata( + pcp_use_hybrid_attn=self.pcp_use_hybrid_attn, num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(), + pcp_unpad_mask=torch.from_numpy(pcp_unpad_mask), + pcp_padded_tokens_fla=self.pcp_padded_tokens_fla, ) if ori_query_lens_cpu is not None: long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu @@ -703,9 +901,20 @@ class PCPManager: "head_attn_nomask_seqlens": head_attn_nomask_seqlens, "tail_attn_nomask_seqlens": tail_attn_nomask_seqlens, } - long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[ - :num_actual_tokens_pcp_padded - ] + if not self.pcp_use_hybrid_attn: + long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[ + :num_actual_tokens_pcp_padded + ] + else: + long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[ + : num_scheduled_tokens.sum() - num_decodes + ] + long_seq_metadata.pcp_fa_query_idx = self.pcp_fa_query_idx[ + : num_actual_tokens_pcp_padded // self.pcp_world_size - num_decodes + ] + long_seq_metadata.pcp_enter_fa_restore_idx = self.pcp_enter_fa_restore_idx[ + : pcp_unpad_mask.sum() + num_decodes * (self.pcp_world_size - 1) + ] long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor long_seq_metadata.q_full_idx = self.q_full_idx