diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py index 66dbbc0b..40f5f700 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py @@ -80,6 +80,7 @@ def test_models_pcp_dcp_basic(): decode_context_parallel_size=1, max_num_batched_tokens=1024, enable_expert_parallel=True, + long_prefill_token_threshold=4, gpu_memory_utilization=0.8, block_size=128) as runner: runner.model.generate(prompts, sampling_params) diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index 158d951a..c3866b13 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -169,8 +169,6 @@ class TestAscendAttentionCPImpl(TestBase): attn_metadata.prefill.chunked_context = MagicMock() local_context_lens_allranks = torch.tensor([[[256, 256], [256, 256]]]) attn_metadata.prefill.chunked_context.local_context_lens_allranks = local_context_lens_allranks - attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint( - 0, 2, (1024, ), dtype=torch.bool) attn_metadata.prefill.chunked_context.local_total_toks = local_context_lens_allranks[:, 0, 0].sum( diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 14e0895a..eadd4fd7 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -141,12 +141,14 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp assert num_computed_tokens_of_pcp_dcp is not None chunked_context_metadata = None + attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens if num_prefills > 0: query_lens = query_lens[num_decode_tokens:] context_lens_cpu = num_computed_tokens_cpu[num_decodes:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() - pcp_size = get_pcp_group().world_size if self.chunked_prefill_enabled and max_context_len_cpu > 0: + if self.pcp_size > 1 and common_long_seq_metadata.pcp_use_hybrid_attn: + query_lens = attn_mask_seqlens[0] * 2 local_context_lens_allranks = ( torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :] .to(self.device) @@ -163,7 +165,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): # when only using dcp. if self.pcp_size > 1: kv_inverse_idx_for_chunk = torch.argsort( - common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to( + common_long_seq_metadata.pcp_allgather_restore_idx[self.pcp_size * num_decode_tokens :].to( torch.float32 ) ) @@ -172,29 +174,23 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): kv_inverse_idx_for_chunk = None cp_kv_recover_idx_for_chunk = None - batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0 - batch_chunk_seq_mask = torch.repeat_interleave( - batch_chunk_seq_mask, repeats=(query_lens * self.pcp_size).to(self.device) - ) chunk_seq_mask_filtered_indices = filter_chunked_req_indices(query_lens, chunked_req_mask).to( self.device ) chunked_context_metadata = AscendMetadataForPrefill.ChunkedContextMetadata( - actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0), + actual_chunk_seq_lengths=torch.cumsum(query_lens * self.pcp_size, dim=0), actual_seq_lengths_kv=actual_seq_lengths_kv, chunked_req_mask=chunked_req_mask, starts=local_chunk_starts, local_context_lens_allranks=local_context_lens_allranks, cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk, kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk, - batch_chunk_seq_mask=batch_chunk_seq_mask, chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices, local_total_toks=local_total_toks.item(), ) - attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens - if pcp_size > 1: + if self.pcp_size > 1: attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0], dim=0).tolist() head_attn_nomask_seqlens = torch.cumsum(head_attn_nomask_seqlens[1], dim=0).tolist() tail_attn_nomask_seqlens = torch.cumsum(tail_attn_nomask_seqlens[1], dim=0).tolist() @@ -220,6 +216,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): prefill_metadata = AscendMetadataForPrefill( pcp_metadata=pcp_metadata, + pcp_exit_fa_scatter_idx=common_long_seq_metadata.pcp_exit_fa_scatter_idx, chunked_context=chunked_context_metadata, block_tables=block_table[self.num_decodes_flatten :, ...], actual_seq_lengths_q=torch.cumsum(query_lens, dim=0), @@ -475,9 +472,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx - if attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn: - fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx - query = torch.index_select(query, 0, fa_query_idx) q_head = torch.index_select(query, 0, q_head_idx) q_tail = torch.index_select(query, 0, q_tail_idx) @@ -541,7 +535,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): assert self.value_cache is not None if self.dcp_size > 1: - query = get_dcp_group().all_gather(query, 1) + query = get_dcp_group().all_gather(query.contiguous(), 1) num_heads = self.num_heads * self.dcp_size else: num_heads = self.num_heads @@ -936,6 +930,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size if pcp_use_hybrid_attn: prefill_query = query[self.pcp_size * num_decode_tokens :] + assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None + fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx + prefill_query = torch.index_select(prefill_query, 0, fa_query_idx) else: prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous() key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded].contiguous() @@ -993,7 +990,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): attn_metadata, ) - if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None: + if has_chunked_context: # update the output of current chunk with context part torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream()) global_context_output = global_context_output.permute([2, 0, 1]).contiguous() @@ -1005,9 +1002,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): if self.pcp_size > 1 and pcp_use_hybrid_attn: # layer_idx != num_layers - 1 assert attn_metadata.prefill.pcp_metadata is not None - pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx + pcp_exit_fa_scatter_idx = attn_metadata.prefill.pcp_exit_fa_scatter_idx attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0) - attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_allgather_restore_idx) + attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_exit_fa_scatter_idx) fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0] output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0) diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index e06b0e11..c319698f 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -78,11 +78,11 @@ class AscendMetadataForPrefill: local_context_lens_allranks: list[list[int]] | None = None cp_kv_recover_idx_for_chunk: list[int] | None = None kv_inverse_idx_for_chunk: list[int] | None = None - batch_chunk_seq_mask: list[bool] | None = None local_total_toks: int | None = None """ Prefill Specific Metadata for Ascend""" pcp_metadata: AscendPCPMetadata | None = None + pcp_exit_fa_scatter_idx: torch.Tensor | None = None chunked_context: ChunkedContextMetadata | None = None block_tables: torch.Tensor = None actual_seq_lengths_q: torch.Tensor = None diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 34244e0f..946d5c66 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -113,6 +113,10 @@ class AscendPrefillContextParallelMetadata: # when entering from linear-attention to attention pcp_enter_fa_restore_idx: torch.Tensor = None + # scatter the full sequence across all pcp ranks + # when exiting from attention to linear-attention + pcp_exit_fa_scatter_idx: torch.Tensor = None + # the number of tokens padded in linear-attn per rank pcp_padded_tokens_fla: int = 0 diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index b46fc52a..e1cf42c9 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -75,6 +75,12 @@ class PCPManager: device=device, pin_memory=pin_memory, ) + self.pcp_exit_fa_scatter_idx = CpuGpuBuffer( + max_buffer_num_tokens, + dtype=torch.int64, + device=device, + pin_memory=pin_memory, + ) self.pcp_padded_slot_mapping = torch.full( (max_buffer_num_tokens,), fill_value=-1, @@ -110,9 +116,9 @@ class PCPManager: self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory ) self.positions_pcp_full_np = self.positions_pcp_full.numpy() - self.query_lens_pcp_full = CpuGpuBuffer( - self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory - ) + self.query_lens_pcp_full = CpuGpuBuffer( + self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory + ) self.pcp_fa_query_idx = torch.zeros( self.max_num_tokens + 2 * self.max_num_reqs, dtype=torch.int32, device=self.device ) @@ -164,6 +170,10 @@ class PCPManager: self.num_prefill_reqs = num_reqs - self.num_decode_reqs self.num_decode_tokens = num_scheduled_tokens[: self.num_decode_reqs].sum() + self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens) + self.query_lens_pcp_full.cpu[self.num_reqs :].fill_(0) + self.query_lens_pcp_full.copy_to_gpu() + def update_tokens_for_pcp( self, num_scheduled_tokens: np.ndarray, @@ -301,6 +311,17 @@ class PCPManager: num_scheduled_tokens[: self.num_decode_reqs], arange_np )[1] + # Build the restore index used after allgather. + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort() + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + + self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs] + self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum() + if self.pcp_use_hybrid_attn: max_scheduled_prefill_tokens = 0 self.pcp_padded_tokens_fla = 0 @@ -405,7 +426,7 @@ class PCPManager: for rank_i in range(self.pcp_world_size) ] all_positions_prefill_tensor = torch.from_numpy(np.concatenate(all_positions_prefill)) - all_enter_fla_restore_idx = all_positions_prefill_tensor.float().argsort() + all_exit_fa_restore_idx = all_positions_prefill_tensor.float().argsort() unpad_mask_prefill = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length][ self.num_decode_reqs * self.pcp_world_size : ] @@ -413,14 +434,15 @@ class PCPManager: ori_tokens_start_loc = np.roll(np.cumsum(num_scheduled_tokens[self.num_decode_tokens :]), 1) ori_tokens_start_loc[0] = 0 # [0,1,2] [3,4] | [0,1,7,8] [2,3,9] [4,5,10] [6,11] - enter_fla_scatter_idx = positions_linear[self.num_decode_reqs :] + np.repeat( + exit_fa_scatter_indices = positions_linear[self.num_decode_reqs :] + np.repeat( ori_tokens_start_loc, num_prefill_scheduled_tokens_linear ) - enter_fla_restore_idx = torch.index_select( - all_enter_fla_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(enter_fla_scatter_idx) + + exit_fa_scatter_idx = torch.index_select( + all_exit_fa_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(exit_fa_scatter_indices) ) - self.pcp_allgather_restore_idx.gpu[: enter_fla_restore_idx.shape[0]].copy_( - enter_fla_restore_idx.long(), non_blocking=True + self.pcp_exit_fa_scatter_idx.gpu[: exit_fa_scatter_idx.shape[0]].copy_( + exit_fa_scatter_idx.long(), non_blocking=True ) positions_prefill = all_positions_prefill[self.pcp_world_rank] @@ -434,18 +456,7 @@ class PCPManager: self.pcp_tokens_padded = pcp_tokens[: self.num_reqs] self.num_scheduled_tokens_padded = np.array(self.pcp_tokens_padded, dtype=np.int32) return num_padded_scheduled_tokens, positions_linear - else: - # Build the restore index used after allgather. - all_positions_lst = [ - get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size) - ] - all_positions = np.concatenate(all_positions_lst) - self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort() - self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) - - self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs] - self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum() - return pcp_tokens[: self.num_reqs], positions + return pcp_tokens[: self.num_reqs], positions def get_logits_indices( self, @@ -539,7 +550,6 @@ class PCPManager: num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32) for i, req_id in enumerate(input_batch.req_ids): num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] - self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens_pcp_full) req_indices_pcp_full = np.repeat(arange_np[: self.num_reqs], num_scheduled_tokens_pcp_full) cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) self.query_start_loc_pcp_full.np[0] = 0 @@ -567,7 +577,6 @@ class PCPManager: cu_num_tokens_pcp_full, num_spec_tokens, ) - self.query_lens_pcp_full.copy_to_gpu() self.query_start_loc_pcp_full.copy_to_gpu() self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full) self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full @@ -719,15 +728,10 @@ class PCPManager: if self.pcp_world_size > 1 and self.pcp_use_hybrid_attn: assert self.num_scheduled_tokens_padded is not None total_num_scheduled_tokens = self.num_scheduled_tokens_padded.sum() - query_lens_new = ( - self.query_lens_pcp_full.cpu[:num_reqs] - if self.pcp_world_size > 1 and self.speculative_config - else query_lens - ) - num_decodes = (query_lens_new <= self.decode_threshold).sum().item() num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded long_seq_metadata = None + ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded] if self.pcp_world_size * self.dcp_world_size > 1: assert num_scheduled_tokens is not None decode_context_lens = ( @@ -753,7 +757,6 @@ class PCPManager: self.vllm_config.parallel_config.cp_kv_cache_interleave_size, ) ) - ori_query_lens_cpu = None if self.decode_threshold > 1: num_computed_tokens_of_pcp_dcp_list = [] if self.num_decode_reqs: @@ -781,7 +784,6 @@ class PCPManager: # (num_reqs_d + num_reqs_p, max_num_blocks), # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded] ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded] num_prefill_reqs = self.num_prefill_reqs num_decode_reqs = self.num_decode_reqs @@ -806,10 +808,9 @@ class PCPManager: num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(), pcp_unpad_mask=torch.from_numpy(pcp_unpad_mask), pcp_padded_tokens_fla=self.pcp_padded_tokens_fla, + query_lens_pcp_full_cpu=ori_query_lens_cpu, + max_query_len_pcp_full=ori_query_lens_cpu.max().item(), ) - if ori_query_lens_cpu is not None: - long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu - long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item() if self.pcp_world_size > 1: q_head_idx, q_tail_idx = [], [] kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] @@ -906,19 +907,18 @@ class PCPManager: "head_attn_nomask_seqlens": head_attn_nomask_seqlens, "tail_attn_nomask_seqlens": tail_attn_nomask_seqlens, } - if not self.pcp_use_hybrid_attn: - long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[ - :num_actual_tokens_pcp_padded - ] - else: - long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[ - : num_scheduled_tokens.sum() - num_decodes + long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[ + :num_actual_tokens_pcp_padded + ] + if self.pcp_use_hybrid_attn: + long_seq_metadata.pcp_exit_fa_scatter_idx = self.pcp_exit_fa_scatter_idx.gpu[ + : num_scheduled_tokens.sum() - self.num_decode_reqs ] long_seq_metadata.pcp_fa_query_idx = self.pcp_fa_query_idx[ - : num_actual_tokens_pcp_padded // self.pcp_world_size - num_decodes + : num_actual_tokens_pcp_padded // self.pcp_world_size - self.num_decode_reqs ] long_seq_metadata.pcp_enter_fa_restore_idx = self.pcp_enter_fa_restore_idx[ - : pcp_unpad_mask.sum() + num_decodes * (self.pcp_world_size - 1) + : pcp_unpad_mask.sum() + self.num_decode_reqs * (self.pcp_world_size - 1) ] long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor