diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index dfec5a3c..163bcfb8 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -225,21 +225,22 @@ class TestMtpProposer: mock_deps.runner.spec_decode_common_attn_metadata = MagicMock() mock_deps.runner.pcp_size = 2 mock_deps.runner.dcp_size = 1 - mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer( + mock_deps.runner.pcp_manager = MagicMock() + mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer( 32, dtype=torch.int32, pin_memory=False, device='cpu', ) - mock_deps.runner.input_ids_pcp_full.cpu = \ + mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \ torch.arange(32, dtype=torch.int32) - mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer( + mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer( 5, dtype=torch.int32, pin_memory=False, device='cpu', ) - mock_deps.runner.query_start_loc_pcp_full.cpu = \ + mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \ torch.tensor([0, 8, 16, 24, 32]) mock_deps.positions = torch.arange(16, dtype=torch.int32) mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16) diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py deleted file mode 100644 index 8ff26a6f..00000000 --- a/tests/ut/worker/test_model_runner_v1.py +++ /dev/null @@ -1,473 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. - -from unittest.mock import MagicMock - -import numpy as np -import pytest -import torch - -from vllm_ascend.worker.model_runner_v1 import NPUModelRunner - - -@pytest.mark.parametrize( - "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none", - [ - (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False), - (1, 2, 3, [20, 30, 40], 1, False, 50, True), - (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True), - (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True), - (2, 1, 3, [5, 10, 15], 3, False, 50, True), - (2, 1, 3, [40, 50, 60], 0, False, 150, True), - ]) -def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, - num_decodes, use_mla, total_tokens, - expect_not_none): - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = pcp_size - mock_runner.dcp_size = dcp_size - mock_runner.decode_threshold = 4 - mock_runner.pcp_rank = 0 - mock_runner.device = torch.device('cpu') - mock_runner.dtype = torch.float32 - - mock_runner.parallel_config = MagicMock() - mock_runner.parallel_config.cp_kv_cache_interleave_size = 64 - - mock_runner.vllm_config = MagicMock() - mock_runner.vllm_config.model_config = MagicMock() - mock_runner.vllm_config.model_config.use_mla = use_mla - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = num_reqs - mock_runner.speculative_config = None - - num_computed_tokens = [] - num_prompt_tokens = [] - num_tokens = [] - - for i in range(num_reqs): - if i < num_decodes: - num_computed_tokens.append(query_lens[i]) - num_prompt_tokens.append(query_lens[i] // 2) - num_tokens.append(query_lens[i]) - else: - num_computed_tokens.append(0) - num_prompt_tokens.append(query_lens[i]) - num_tokens.append(query_lens[i]) - - mock_runner.input_batch.num_computed_tokens_cpu = torch.tensor( - num_computed_tokens) - mock_runner.input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens) - mock_runner.input_batch.num_tokens = torch.tensor(num_tokens) - - mock_runner.query_lens = torch.tensor(query_lens) - - mock_runner._get_cp_local_seq_lens = NPUModelRunner._get_cp_local_seq_lens.__get__( - mock_runner, NPUModelRunner) - - mock_runner.pcp_allgather_restore_idx = torch.arange(total_tokens * 2) - mock_runner.cp_kv_recover_idx_for_chunk = torch.arange(total_tokens) - - mock_runner.long_seq_metadata = None - mock_runner.num_actual_tokens_pcp_padded = 0 - mock_runner.kv_idx_names = {} - mock_runner.extra_long_seq_kwargs = {} - mock_runner.attn_mask = None - mock_runner.q_head_idx_tensor = None - mock_runner.q_tail_idx_tensor = None - mock_runner.q_full_idx = None - - method = NPUModelRunner._generate_pcp_metadata.__get__( - mock_runner, NPUModelRunner) - result = method(total_tokens) - - if not expect_not_none: - assert result is None, f"Expected to return None, but got {type(result)}" - else: - assert result is not None, "Expected to return a metadata object, but got None." - - assert hasattr(result, 'num_actual_tokens_pcp_padded') - assert hasattr(result, 'num_computed_tokens_of_pcp_dcp') - - if pcp_size > 1: - assert hasattr(result, 'pcp_allgather_restore_idx') - - has_prefill_requests = (num_reqs - num_decodes) > 0 - if has_prefill_requests: - assert hasattr(result, 'q_head_idx_tensor') - assert hasattr(result, 'q_tail_idx_tensor') - assert hasattr(result, 'q_full_idx') - assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor') - assert hasattr(result, 'kv_with_q_head_mask_idx_tensor') - assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor') - assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor') - assert hasattr(result, 'attn_mask_seqlens') - assert hasattr(result, 'head_attn_nomask_seqlens') - assert hasattr(result, 'tail_attn_nomask_seqlens') - - if hasattr(result, 'pcp_prefill_mask' - ) and result.pcp_prefill_mask is not None: - if use_mla: - assert result.pcp_prefill_mask.shape == (512, 512) - else: - assert result.pcp_prefill_mask.shape == (2048, 2048) - else: - if hasattr(result, 'pcp_prefill_mask'): - if result.pcp_prefill_mask is not None: - if use_mla: - assert result.pcp_prefill_mask.shape == (512, 512) - else: - assert result.pcp_prefill_mask.shape == (2048, - 2048) - - -def test_generate_pcp_metadata_edge_cases(): - mock_runner = MagicMock() - mock_runner.pcp_size = 2 - mock_runner.dcp_size = 1 - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = 0 - mock_runner.query_lens = torch.tensor([10, 20, 30]) - - assert (mock_runner.input_batch.num_reqs - or mock_runner.query_lens.size(0)) == 3 - - mock_runner.input_batch.num_reqs = 100 - mock_runner.query_lens = torch.ones(100) * 1000 - - for rank in [0, 1]: - mock_runner.pcp_rank = rank - q_head_chunk_id = rank - q_tail_chunk_id = 2 * 2 - 1 - rank - assert q_head_chunk_id == rank - assert q_tail_chunk_id == 3 - rank - - -def test_pcp_allgather_restore_idx_slicing(): - mock_runner = MagicMock() - mock_runner.pcp_size = 2 - mock_runner.pcp_allgather_restore_idx = torch.arange(1000) - - total_num_scheduled_tokens = 200 - num_actual_tokens_pcp_padded = total_num_scheduled_tokens * 2 - - expected_slice = mock_runner.pcp_allgather_restore_idx[: - num_actual_tokens_pcp_padded] - assert len(expected_slice) == 400 - assert expected_slice[0] == 0 - assert expected_slice[-1] == 399 - - -@pytest.mark.parametrize( - "tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \ - "pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens", - [ - # Case 1: prefill only - ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]), - - # Case 2: mix prefill and decode (with spec decode) - ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]), - - # Case 3: request which need to be padded - ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]), - - # Case 4: single request - ([10], 1, [0], [10], 4, 0, 1, [4]), - ]) -def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, - num_prompt_tokens, pcp_size, pcp_rank, - decode_threshold, expected_pcp_tokens): - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = pcp_size - mock_runner.pcp_rank = pcp_rank - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = num_reqs - mock_runner.input_batch.num_computed_tokens_cpu = np.array( - num_computed_tokens, dtype=np.int32) - mock_runner.input_batch.num_prompt_tokens = np.array(num_prompt_tokens, - dtype=np.int32) - - mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) - - mock_runner.num_pcp_pads = [0] * num_reqs - mock_runner.arange_np = np.arange(10000) - mock_runner.decode_threshold = decode_threshold - - mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( - mock_runner, NPUModelRunner) - mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__( - mock_runner, NPUModelRunner) - - pcp_tokens_result, positions_result, unpad_mask_result = mock_runner._update_tokens_for_pcp( - tokens) - - assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \ - f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}" - - total_pcp_tokens: int = np.sum(pcp_tokens_result) - assert positions_result.shape == (total_pcp_tokens,), \ - f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}" - - padded_tokens = [ - (t + 2 * pcp_size - 1) // (2 * pcp_size) * - (2 * pcp_size) if num_computed_tokens[i] == 0 else t * pcp_size - for i, t in enumerate(tokens) - ] - total_padded_tokens: int = np.sum(padded_tokens) - assert unpad_mask_result.shape[0] == total_padded_tokens, \ - f"unpad_mask size mismatch: expected {total_padded_tokens}, got {unpad_mask_result.shape[0]}" - - -def test_update_tokens_for_pcp_with_padding(): - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = 4 - mock_runner.pcp_rank = 0 - - mock_runner.arange_np = np.arange(10000) - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = 3 - mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0, 0], - dtype=np.int32) - mock_runner.input_batch.num_prompt_tokens = np.array([5, 9, 13], - dtype=np.int32) - - mock_runner.num_pcp_pads = [0, 0, 0] - mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) - mock_runner.decode_threshold = 1 - - mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( - mock_runner, NPUModelRunner) - mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__( - mock_runner, NPUModelRunner) - - tokens = [5, 9, 13] - - pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp( - tokens) - - expected_pcp_tokens = [2, 4, 4] - assert np.array_equal(pcp_tokens, expected_pcp_tokens), \ - f"Expected {expected_pcp_tokens}, got {pcp_tokens}" - - expected_pads = [3, 7, 3] - assert np.array_equal(mock_runner.num_pcp_pads, expected_pads), \ - f"Expected padding {expected_pads}, got {mock_runner.num_pcp_pads}" - - -def test_update_tokens_for_pcp_unpad_mask(): - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = 4 - mock_runner.pcp_rank = 0 - - mock_runner.arange_np = np.arange(10000) - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = 2 - mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0], - dtype=np.int32) - mock_runner.input_batch.num_prompt_tokens = np.array([5, 7], - dtype=np.int32) - - mock_runner.num_pcp_pads = [0, 0] - mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) - mock_runner.decode_threshold = 1 - - mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( - mock_runner, NPUModelRunner) - mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__( - mock_runner, NPUModelRunner) - - tokens = [5, 7] - - pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp( - tokens) - - assert unpad_mask.dtype == torch.bool, \ - f"unpad_mask should be bool, got {unpad_mask.dtype}" - - padded_tokens = [8, 8] - expected_length = sum(padded_tokens) - assert unpad_mask.shape[0] == expected_length, \ - f"unpad_mask length mismatch: expected {expected_length}, got {unpad_mask.shape[0]}" - - expected_mask = [True] * 5 + [False] * 3 + [True] * 7 + [False] * 1 - actual_mask = unpad_mask.numpy().tolist() - assert actual_mask == expected_mask, \ - f"unpad_mask incorrect. Expected {expected_mask}, got {actual_mask}" - - -# yapf: disable -@pytest.mark.parametrize( - "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target", - [ - # without pcp and dcp - (torch.tensor([1, 2, 128, 129]), 1, 1, 1, - torch.tensor([[[1]], [[2]], [[128]], [[129]]])), - # pcp - (torch.tensor([1, 2, 128, 129]), 2, 1, 1, - torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])), - # dcp - (torch.tensor([1, 2, 128, 129]), 1, 2, 1, - torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])), - # pcp + dcp - (torch.tensor([1, 2, 128, 129]), 2, 2, 1, - torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]], - [[32, 32], [32, 32]], [[33, 32], [32, 32]]])), - # specify interleave_size - (torch.tensor([1, 2, 128, 129]), 2, 1, 2, - torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])), - (torch.tensor([1, 2, 128, 129]), 2, 1, 128, - torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])), - (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128, - torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]], - [[128, 0], [0, 0]], [[128, 1], [0, 0]], - [[128, 128], [0, 0]], [[128, 128], [1, 0]]])), - ] -) -# yapf: enable -def test_get_cp_local_seq_lens( - seq_lens, - pcp_world_size, - dcp_world_size, - cp_kv_cache_interleave_size, - target, -): - mock_runner = MagicMock(spec=NPUModelRunner) - ret = NPUModelRunner._get_cp_local_seq_lens(mock_runner, seq_lens, - pcp_world_size, dcp_world_size, - cp_kv_cache_interleave_size) - assert torch.equal(ret, target) - - -@pytest.fixture -def pcp_mtp_mock_runner(): - # set up pcp & mtp related buffers - max_num_reqs = 4 - max_model_len = 4096 - max_num_tokens = 4096 - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.device = 'cpu' - mock_runner.pin_memory = False - - # Init model_runner pcp_mtp related buffers - mock_runner.query_start_loc_pcp_full = NPUModelRunner._make_buffer( - mock_runner, max_num_reqs + 1, dtype=torch.int32) - - positions_buff = torch.zeros(max_num_tokens, - dtype=torch.int64, - device="cpu") - mock_runner.positions_pcp_full = positions_buff - mock_runner.positions_pcp_full_np = positions_buff.numpy() - - mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer( - mock_runner, max_num_tokens, dtype=torch.int32) - mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer( - mock_runner, max_num_reqs, dtype=torch.int32) - mock_runner.decode_threshold = 1 - - mock_runner.arange_np = np.arange(max_model_len) - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_computed_tokens_cpu = \ - np.zeros(max_num_reqs, dtype=np.int32) - token_ids_cpu_tensor = torch.zeros( - (max_num_reqs, max_model_len), - device="cpu", - dtype=torch.int32, - ) - mock_runner.input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor - mock_runner.input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy() - return mock_runner - - -# yapf: disable -@pytest.mark.parametrize( - "req_ids, num_computed_tokens," \ - "token_ids_tensor_list," \ - "num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \ - "target_input_ids_pcp_full, target_query_start_loc_pcp_full", - [ - # prefill - ( - ['0'], np.array([0]), - [torch.tensor([0, 671, 6102, 294, 8760, 344])], - 1, 6, {'0': 6}, - torch.tensor([0, 671, 6102, 294, 8760, 344]), - torch.tensor([0, 6]) - ), - # decode - ( - ['0'], np.array([6]), - [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])], - 1, 2, {'0': 2}, - torch.tensor([88907, 0]), - torch.tensor([0, 2]) - ), - # decode + prefill - ( - ['0', '1'], np.array([6, 0]), - [ - torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), - torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), - ], - 2, 12, {'0': 2, '1': 10}, - torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), - torch.tensor([0, 2, 12]) - ), - # decodes + prefills - ( - ['0', '1', '2', '3'], np.array([6, 8, 0, 0]), - [ - torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), - torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]), - torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]), - torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]), - ], - 4, 19, {'0': 2, '1': 2, '2': 8, '3': 7}, - torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907, - 0, 671, 5335, 1469, 7539, 305, 6397]), - torch.tensor([0, 2, 4, 12, 19]) - ), - ]) -# yapf: enable -def test_generate_pcp_mtp_input( - pcp_mtp_mock_runner, - req_ids, - num_computed_tokens, - token_ids_tensor_list, - num_reqs, - total_num_scheduled_tokens, - num_scheduled_tokens, - target_input_ids_pcp_full, - target_query_start_loc_pcp_full, -): - mock_runner = pcp_mtp_mock_runner - token_ids_cpu_tensor = mock_runner.input_batch.token_ids_cpu_tensor - - # Set input_batch - mock_runner.input_batch.req_ids = req_ids - mock_runner.input_batch.num_computed_tokens_cpu[:num_computed_tokens. - size] = num_computed_tokens - for i, token_ids_tensor in enumerate(token_ids_tensor_list): - token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor - - NPUModelRunner._generate_pcp_mtp_input(mock_runner, num_reqs, - total_num_scheduled_tokens, - num_scheduled_tokens) - assert torch.equal( - mock_runner.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], - target_input_ids_pcp_full) - assert torch.equal(mock_runner.query_start_loc_pcp_full.cpu[:num_reqs + 1], - target_query_start_loc_pcp_full) diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py new file mode 100644 index 00000000..9f5863ab --- /dev/null +++ b/tests/ut/worker/test_pcp_manager.py @@ -0,0 +1,322 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch + +from vllm_ascend.worker.pcp_utils import PCPManager + + +@pytest.mark.parametrize( + "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none", + [ + (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False), + (1, 2, 3, [20, 30, 40], 1, False, 50, True), + (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True), + (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True), + (2, 1, 3, [5, 10, 15], 3, False, 50, True), + (2, 1, 3, [40, 50, 60], 0, False, 150, True), + ]) +def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, + num_decodes, use_mla, total_tokens, + expect_not_none): + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.model_config.use_mla = use_mla + vllm_config.parallel_config.cp_kv_cache_interleave_size = 64 + vllm_config.speculative_config.num_speculative_tokens = 0 + + pcp_manager = PCPManager(pcp_world_size=pcp_size, + pcp_rank=0, + dcp_world_size=dcp_size, + dcp_rank=0, + max_buffer_num_tokens=10000, + max_num_reqs=1000, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + input_batch = MagicMock() + input_batch.num_reqs = num_reqs + + num_computed_tokens = [] + num_prompt_tokens = [] + num_tokens = [] + + for i in range(num_reqs): + if i < num_decodes: + num_computed_tokens.append(query_lens[i]) + num_prompt_tokens.append(query_lens[i] // 2) + num_tokens.append(query_lens[i]) + else: + num_computed_tokens.append(0) + num_prompt_tokens.append(query_lens[i]) + num_tokens.append(query_lens[i]) + + input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens) + input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens) + input_batch.num_tokens = torch.tensor(num_tokens) + + query_lens = torch.tensor(query_lens) + result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None, + input_batch) + + if not expect_not_none: + assert result is None, f"Expected to return None, but got {type(result)}" + else: + assert result is not None, "Expected to return a metadata object, but got None." + + assert hasattr(result, 'num_actual_tokens_pcp_padded') + assert hasattr(result, 'num_computed_tokens_of_pcp_dcp') + + if pcp_size > 1: + assert hasattr(result, 'pcp_allgather_restore_idx') + + has_prefill_requests = (num_reqs - num_decodes) > 0 + if has_prefill_requests: + assert hasattr(result, 'q_head_idx_tensor') + assert hasattr(result, 'q_tail_idx_tensor') + assert hasattr(result, 'q_full_idx') + assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor') + assert hasattr(result, 'kv_with_q_head_mask_idx_tensor') + assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor') + assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor') + assert hasattr(result, 'attn_mask_seqlens') + assert hasattr(result, 'head_attn_nomask_seqlens') + assert hasattr(result, 'tail_attn_nomask_seqlens') + + if hasattr(result, 'pcp_prefill_mask' + ) and result.pcp_prefill_mask is not None: + if use_mla: + assert result.pcp_prefill_mask.shape == (512, 512) + else: + assert result.pcp_prefill_mask.shape == (2048, 2048) + else: + if hasattr(result, 'pcp_prefill_mask'): + if result.pcp_prefill_mask is not None: + if use_mla: + assert result.pcp_prefill_mask.shape == (512, 512) + else: + assert result.pcp_prefill_mask.shape == (2048, + 2048) + + +@pytest.mark.parametrize( + "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens", + [ + # Case 1: prefill only + ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]), + + # # Case 2: mix prefill and decode + ([8, 4, 12], 3, [8, 4, 0], [8, 0, 12], 4, 0, [2, 2, 4]), + + # # Case 3: request which need to be padded + ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]), + + # Case 4: single request + ([10], 1, [0], [10], 4, 0, [4]), + ]) +def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, + num_prompt_tokens, pcp_size, pcp_rank, + expected_pcp_tokens): + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.speculative_config.num_speculative_tokens = 0 + + pcp_manager = PCPManager(pcp_world_size=pcp_size, + pcp_rank=0, + dcp_world_size=1, + dcp_rank=0, + max_buffer_num_tokens=10000, + max_num_reqs=1000, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + input_batch = MagicMock() + input_batch.num_reqs = num_reqs + input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens, + dtype=np.int32) + input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32) + arange_np = np.arange(10000) + pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp( + np.array(tokens), arange_np, num_reqs, 1) + + assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \ + f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}" + + total_pcp_tokens: int = np.sum(pcp_tokens_result) + assert positions_result.shape == (total_pcp_tokens,), \ + f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}" + + +# yapf: disable +@pytest.mark.parametrize( + "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target", + [ + # without pcp and dcp + (torch.tensor([1, 2, 128, 129]), 1, 1, 1, + torch.tensor([[[1]], [[2]], [[128]], [[129]]])), + # pcp + (torch.tensor([1, 2, 128, 129]), 2, 1, 1, + torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])), + # dcp + (torch.tensor([1, 2, 128, 129]), 1, 2, 1, + torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])), + # pcp + dcp + (torch.tensor([1, 2, 128, 129]), 2, 2, 1, + torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]], + [[32, 32], [32, 32]], [[33, 32], [32, 32]]])), + # specify interleave_size + (torch.tensor([1, 2, 128, 129]), 2, 1, 2, + torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])), + (torch.tensor([1, 2, 128, 129]), 2, 1, 128, + torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])), + (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128, + torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]], + [[128, 0], [0, 0]], [[128, 1], [0, 0]], + [[128, 128], [0, 0]], [[128, 128], [1, 0]]])), + ] +) +# yapf: enable +def test_get_cp_local_seq_lens( + seq_lens, + pcp_world_size, + dcp_world_size, + cp_kv_cache_interleave_size, + target, +): + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.speculative_config.num_speculative_tokens = 0 + pcp_manager = PCPManager(pcp_world_size=pcp_world_size, + pcp_rank=0, + dcp_world_size=dcp_world_size, + dcp_rank=0, + max_buffer_num_tokens=10000, + max_num_reqs=1000, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size, + dcp_world_size, + cp_kv_cache_interleave_size) + assert torch.equal(ret, target) + + +# yapf: disable +@pytest.mark.parametrize( + "req_ids, num_computed_tokens," \ + "token_ids_tensor_list," \ + "num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \ + "target_input_ids_pcp_full, target_query_start_loc_pcp_full", + [ + # prefill + ( + ['0'], np.array([0]), + [torch.tensor([0, 671, 6102, 294, 8760, 344])], + 1, 6, {'0': 6}, + torch.tensor([0, 671, 6102, 294, 8760, 344]), + torch.tensor([0, 6]) + ), + # decode + ( + ['0'], np.array([6]), + [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])], + 1, 2, {'0': 2}, + torch.tensor([88907, 0]), + torch.tensor([0, 2]) + ), + # decode + prefill + ( + ['0', '1'], np.array([6, 0]), + [ + torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), + torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), + ], + 2, 12, {'0': 2, '1': 10}, + torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), + torch.tensor([0, 2, 12]) + ), + # decodes + prefills + ( + ['0', '1', '2', '3'], np.array([6, 8, 0, 0]), + [ + torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), + torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]), + torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]), + torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]), + ], + 4, 19, {'0': 2, '1': 2, '2': 8, '3': 7}, + torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907, + 0, 671, 5335, 1469, 7539, 305, 6397]), + torch.tensor([0, 2, 4, 12, 19]) + ), + ]) +# yapf: enable +def test_generate_pcp_mtp_input( + req_ids, + num_computed_tokens, + token_ids_tensor_list, + num_reqs, + total_num_scheduled_tokens, + num_scheduled_tokens, + target_input_ids_pcp_full, + target_query_start_loc_pcp_full, +): + max_num_reqs = 4 + max_model_len = 4096 + max_num_tokens = 4096 + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.speculative_config.num_speculative_tokens = 1 + vllm_config.scheduler_config.max_num_seqs = max_num_reqs + vllm_config.scheduler_config.max_num_batched_tokens = max_model_len + pcp_manager = PCPManager(pcp_world_size=2, + pcp_rank=0, + dcp_world_size=1, + dcp_rank=0, + max_buffer_num_tokens=max_num_tokens, + max_num_reqs=max_num_reqs, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + arange_np = np.arange(max_model_len) + input_batch = MagicMock() + input_batch.num_computed_tokens_cpu = \ + np.zeros(max_num_reqs, dtype=np.int32) + token_ids_cpu_tensor = torch.zeros( + (max_num_reqs, max_model_len), + device="cpu", + dtype=torch.int32, + ) + input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor + input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy() + token_ids_cpu_tensor = input_batch.token_ids_cpu_tensor + + # Set input_batch + input_batch.req_ids = req_ids + input_batch.num_computed_tokens_cpu[:num_computed_tokens. + size] = num_computed_tokens + for i, token_ids_tensor in enumerate(token_ids_tensor_list): + token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor + + pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens, + num_scheduled_tokens, False, + input_batch, arange_np) + assert torch.equal( + pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], + target_input_ids_pcp_full) + assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1], + target_query_start_loc_pcp_full) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 19de161c..30c6396d 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -279,9 +279,9 @@ class EagleProposer(VllmEagleProposer): req_scheduled_tokens = scheduler_output.num_scheduled_tokens if self.pcp_size > 1: long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.input_ids_pcp_full - query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full - query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu + input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu + query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu + query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu num_reqs = self.runner.input_batch.num_reqs ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ query_start_loc_pcp_full_cpu[:num_reqs] diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 84049d8a..51bc6325 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -179,7 +179,7 @@ class MtpProposer(EagleProposer): if self.pcp_size * self.dcp_size > 1: # update long_seq related params and flatten block_table common_attn_metadata.prefill_context_parallel_metadata = \ - self.runner.long_seq_metadata + self.runner.pcp_manager.long_seq_metadata common_attn_metadata.block_table_tensor = \ self.runner.input_batch.block_table[0].get_device_tensor()[ :num_reqs * self.decode_threshold] @@ -286,9 +286,9 @@ class MtpProposer(EagleProposer): req_scheduled_tokens = scheduler_output.num_scheduled_tokens if self.pcp_size * self.dcp_size > 1: long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu - query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu - query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu + input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu + query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu + query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu num_reqs = self.runner.input_batch.num_reqs ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ query_start_loc_pcp_full_cpu[:num_reqs] @@ -303,7 +303,7 @@ class MtpProposer(EagleProposer): # update pcp related params if self.pcp_size > 1: token_indices_to_sample = \ - query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1 + query_start_loc_pcp_full[1:num_reqs + 1] - 1 target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states @@ -751,8 +751,8 @@ class MtpProposer(EagleProposer): hidden_states = hidden_states[:num_tokens] hidden_states = get_pcp_group().all_gather(hidden_states, 0) hidden_states = torch.index_select( - hidden_states, 0, self.runner. - pcp_allgather_restore_idx[:hidden_states.shape[0]]) + hidden_states, 0, self.runner.pcp_manager. + pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]]) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -797,13 +797,13 @@ class MtpProposer(EagleProposer): # (_generate_pcp_mtp_input), and use updated slot_indices # to get corresponding slot_mapping in each step. num_reject_tokens = torch.tensor( - self.runner.cu_num_tokens_pcp_full, + self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to( self.device) - ori_last_token_indices - 1 num_accept_tokens = \ query_lens_d.to(self.device) - num_reject_tokens ori_seq_len = attn_metadata_i.seq_lens - mtp_slot_mapping = self.runner.mtp_slot_pad + mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad # slot_mapping index base offset: # scheduled tokens + pre-allocated mtp tokens + accepted tokens @@ -889,7 +889,7 @@ class MtpProposer(EagleProposer): self.hidden_states[:hidden_states.shape[0]] = hidden_states if self.pcp_size * self.dcp_size > 1: # update local seq_len and batch_seq_mask - num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens( + num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( ori_seq_len + step + 1, self.pcp_size, self.dcp_size, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b4ddf436..319c7e41 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -24,7 +24,7 @@ from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from dataclasses import dataclass from multiprocessing import Manager -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union import numpy as np import torch @@ -78,8 +78,7 @@ from vllm.v1.worker.utils import AttentionGroup from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - AscendPrefillContextParallelMetadata) +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata # yapf conflicts with isort for this block # yapf: disable from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, @@ -109,6 +108,7 @@ from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration, lmhead_tp_enable, maybe_trans_nz, set_weight_prefetch_method) from vllm_ascend.worker.npu_input_batch import NPUInputBatch +from vllm_ascend.worker.pcp_utils import PCPManager from vllm_ascend.ascend_forward_context import ( # isort: skip MoECommType, get_mc2_tokens_capacity, select_moe_comm_method, @@ -202,6 +202,26 @@ class NPUModelRunner(GPUModelRunner): self.pcp_rank = 0 if self.pcp_size > 1: self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs + max_buffer_num_tokens = self.max_num_tokens + if self.pcp_size * self.dcp_size > 1: + max_buffer_num_tokens = (self.max_num_tokens + + self.max_num_reqs * 2 * self.pcp_size) + self.pcp_manager = PCPManager( + self.pcp_size, + self.pcp_rank, + self.dcp_size, + self.dcp_rank, + max_buffer_num_tokens, + self.max_num_reqs, + self.device, + self.vllm_config, + self.pin_memory, + ) + # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this + self.input_ids = self._make_buffer(max_buffer_num_tokens, + dtype=torch.int32) + self.positions = self._make_buffer(max_buffer_num_tokens, + dtype=torch.int64) self.sampler = AscendSampler() self.attn_mask = None self.attn_state = None @@ -262,32 +282,6 @@ class NPUModelRunner(GPUModelRunner): set_mc2_tokens_capacity(vllm_config, self.max_num_reqs, self.uniform_decode_query_len) set_mc2_mask(vllm_config, self.device) - self.pcp_allgather_restore_idx = torch.zeros( - self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ - [] for _ in range(self.pcp_size) - ] - - self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32) - self.pcp_padded_slot_mapping = torch.zeros( - self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.num_actual_tokens_pcp_padded = 0 - if self.speculative_config and self.pcp_size * self.dcp_size > 1: - self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.query_start_loc_pcp_full = self._make_buffer( - self.max_num_reqs + 1, dtype=torch.int32) - self.positions_pcp_full = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=True) - self.positions_pcp_full_np = self.positions_pcp_full.numpy() - self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) self.decode_threshold = 1 + ( self.speculative_config.num_speculative_tokens if self.speculative_config else 0) @@ -359,6 +353,7 @@ class NPUModelRunner(GPUModelRunner): # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: IntermediateTensors | None = None self.reorder_batch_threshold: int | None = None + self.long_seq_metadata = None def _init_device_properties(self) -> None: self.num_sms = None @@ -508,49 +503,6 @@ class NPUModelRunner(GPUModelRunner): return self.attn_mask_builder.get_mla_mask(self.dtype) return self.attn_mask_builder.get_splitfuse_attn_mask() - def generate_kv_idx(self, scheduler_output): - if not self.pcp_size > 1: - return - self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)] - - for i, req_id in enumerate(self.input_batch.req_ids): - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - is_prefill = self.input_batch.num_computed_tokens_cpu[ - i] < self.input_batch.num_prompt_tokens[i] - if is_prefill: - num_cp_padded_scheduled_tokens = cdiv( - num_scheduled_tokens, - 2 * self.pcp_size) * (2 * self.pcp_size) - full_indices = list( - range(self.max_num_tokens * self.pcp_size * self.dcp_size + - self.pcp_size * self.dcp_size * self.max_num_reqs)) - chunk_size = num_cp_padded_scheduled_tokens // (2 * - self.pcp_size) - num_added_recover_tokens = len( - self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size - for rank in range(self.pcp_size): - self.cp_kv_recover_idx_for_chunk[rank].extend( - full_indices[rank * chunk_size + - num_added_recover_tokens:(rank + 1) * - chunk_size + num_added_recover_tokens]) - self.cp_kv_recover_idx_for_chunk[rank].extend( - full_indices[num_cp_padded_scheduled_tokens - - (rank + 1) * chunk_size + - num_added_recover_tokens: - num_cp_padded_scheduled_tokens - - rank * chunk_size + - num_added_recover_tokens]) - - cp_kv_recover_idx_for_chunk = torch.from_numpy( - np.concatenate( - self.cp_kv_recover_idx_for_chunk)).to(device=self.device) - cp_kv_recover_idx_for_chunk.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( - torch.float32).argsort().to(torch.int32) - def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -574,43 +526,70 @@ class NPUModelRunner(GPUModelRunner): req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - _, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - positions_np = np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - ) - - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) - - total_num_pcp_pads = 0 - if self.pcp_size > 1: - if not self.vllm_config.model_config.use_mla: - self.generate_kv_idx(scheduler_output) - tokens_before_update = tokens.copy() - tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp( - tokens) - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) - total_num_pcp_pads = torch.sum(self.num_pcp_pads[:num_reqs]).item() - else: - position_pcp, pcp_unpad_mask = None, None - self.num_pcp_pads[:num_reqs] = 0 - - max_num_scheduled_tokens = max(tokens) if not scheduler_output.scheduled_spec_decode_tokens: num_valid_tokens = np.array(tokens, dtype=np.int32) else: num_valid_tokens = np.array([ num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) - for num_tokens, i in zip((tokens_before_update if self. - pcp_size > 1 else tokens), req_ids) + for num_tokens, i in zip(tokens, req_ids) ], dtype=np.int32) + # Get the attention state. + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, + num_valid_tokens) + self.attn_state = attn_state # type: ignore + # Determine if it's a splitfuse batch + with_prefill = attn_state not in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + self.attn_mask = self._make_attention_mask(attn_state) + + # Get positions. + positions_np = self.positions.np[:total_num_scheduled_tokens] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + self.input_batch.block_table.compute_slot_mapping( + req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping( + total_num_scheduled_tokens) + # for pcp, prefill mtp should use origin scheduleroutput , + if self.speculative_config and self.pcp_size * self.dcp_size > 1: + self.pcp_manager.generate_pcp_mtp_input( + num_reqs, total_num_scheduled_tokens, + scheduler_output.num_scheduled_tokens, with_prefill, + self.input_batch, self.arange_np, req_indices, positions_np, + cu_num_tokens) + + if self.pcp_size > 1: + if not self.vllm_config.model_config.use_mla: + self.pcp_manager.generate_kv_idx(scheduler_output, + self.input_batch) + num_scheduled_tokens[: + num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp( + num_scheduled_tokens[:num_reqs], + self.arange_np, + self.input_batch.num_reqs, + self.reorder_batch_threshold, + ) + # Re-update after PCP split sequences. + total_num_scheduled_tokens = sum(num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + cu_num_tokens, _ = self._get_cumsum_and_arange( + num_scheduled_tokens) + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + position_pcp[:total_num_scheduled_tokens], + out=positions_np, + ) + max_num_scheduled_tokens = max(tokens) if (self.use_aclgraph and total_num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Add padding to the batch size. @@ -627,17 +606,6 @@ class NPUModelRunner(GPUModelRunner): else: # Eager mode. num_input_tokens = total_num_scheduled_tokens - - # Get the attention state. - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) - self.attn_state = attn_state # type: ignore - - # Determine if it's a splitfuse batch - with_prefill = attn_state not in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - self.query_lens = torch.from_numpy(num_scheduled_tokens) # Get info across DP ranks. @@ -646,7 +614,7 @@ class NPUModelRunner(GPUModelRunner): (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp(num_input_tokens, with_prefill) - + self.with_prefill = with_prefill # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens # We should consider removing maybe_padded_num_tokens later num_input_tokens = maybe_padded_num_tokens @@ -655,24 +623,8 @@ class NPUModelRunner(GPUModelRunner): if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - - if self.pcp_size > 1: - positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - position_pcp[:total_num_scheduled_tokens], - out=positions_np) - else: - self.positions.np[:total_num_scheduled_tokens] = positions_np - + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self._calc_mrope_positions(scheduler_output) @@ -766,21 +718,11 @@ class NPUModelRunner(GPUModelRunner): self.seq_lens.gpu[num_reqs:].fill_(0) - self.query_lens = torch.from_numpy(num_scheduled_tokens) - # Copy the tensors to the NPU. self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, cu_num_tokens) self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions.copy_to_gpu() - - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) - self.attn_mask = self._make_attention_mask(attn_state) - self.attn_state = attn_state # type: ignore - - self.with_prefill = with_prefill - self.num_tokens_across_dp = num_tokens_across_dp attn_metadata: dict[str, Any] = {} # Record the index of requests that should not be sampled, @@ -914,9 +856,8 @@ class NPUModelRunner(GPUModelRunner): # TODO: Support prompt logprobs. spec_decode_metadata = None if self.pcp_size * self.dcp_size > 1: - logits_indices = torch.from_numpy( - cu_num_tokens - ) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1 + logits_indices = self.pcp_manager.get_logits_indices( + cu_num_tokens, num_reqs) logits_indices = logits_indices.pin_memory().to( self.device, non_blocking=True) else: @@ -938,8 +879,10 @@ class NPUModelRunner(GPUModelRunner): >= self.input_batch.num_prompt_tokens[req_idx]) else -1) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens, - self.num_pcp_pads[:num_reqs].numpy()) + num_draft_tokens, + cu_num_tokens, + num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs] + if self.pcp_size > 1 else None) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). @@ -961,23 +904,10 @@ class NPUModelRunner(GPUModelRunner): self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - if self.speculative_config and self.pcp_size * self.dcp_size > 1: - self._generate_pcp_mtp_input( - num_reqs, scheduler_output.total_num_scheduled_tokens, - scheduler_output.num_scheduled_tokens, with_prefill, - req_indices, positions_np, cu_num_tokens) - - long_seq_metadata = self._generate_pcp_metadata( - total_num_scheduled_tokens) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - # NOTE: This is strange, why did we use total_num_scheduled_tokens before? - slot_mapping_size = (total_num_scheduled_tokens - if self.pcp_size == 1 else - total_num_scheduled_tokens * self.pcp_size - - total_num_pcp_pads) if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to @@ -993,30 +923,30 @@ class NPUModelRunner(GPUModelRunner): device=self.device, ) else: + maybe_pcp_full_tokens = ( + num_input_tokens if self.pcp_size == 1 else + total_num_scheduled_tokens * self.pcp_size - + sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])) blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor() - blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0) - if self.pcp_size > 1: - slot_mapping_for_pcp = blk_table.slot_mapping.gpu[: - long_seq_metadata - . - num_actual_tokens_pcp_padded] - slot_mapping_for_pcp[slot_mapping_size:].fill_(-1) - assert pcp_unpad_mask is not None - pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: - pcp_unpad_mask - . - shape[ - 0]] - pcp_padded_slot_mapping.fill_(-1) - pcp_padded_slot_mapping[ - pcp_unpad_mask] = slot_mapping_for_pcp[: - slot_mapping_size] - slot_mapping_for_pcp[:long_seq_metadata. - num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping - blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded] = \ - slot_mapping_for_pcp + slot_mapping = blk_table.slot_mapping.gpu[: + maybe_pcp_full_tokens] + if self.pcp_size * self.dcp_size == 1: + slot_mapping[ + total_num_scheduled_tokens:num_input_tokens].fill_(-1) slot_mapping = blk_table.slot_mapping.gpu + if self.pcp_size * self.dcp_size > 1: + self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( + total_num_scheduled_tokens, self.query_lens, + self.attn_mask, self.input_batch) + blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1) + slot_mapping = slot_mapping[:maybe_pcp_full_tokens] + slot_mapping = self.pcp_manager.get_padded_slot_mapping( + total_num_scheduled_tokens, + slot_mapping, + ) + blk_table.slot_mapping.gpu[:self.pcp_manager. + num_actual_tokens_pcp_padded] = slot_mapping # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs # has been split to multiple parts, and there are 3 parts that is related to this @@ -1055,7 +985,7 @@ class NPUModelRunner(GPUModelRunner): seq_lens_cpu=self.seq_lens.cpu[:num_reqs], seq_lens=self.seq_lens.gpu[:num_reqs], num_reqs=num_reqs, - num_actual_tokens=slot_mapping_size, + num_actual_tokens=total_num_scheduled_tokens, num_input_tokens=num_input_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, # TODO: change this to the right block table for linear attn @@ -1069,8 +999,9 @@ class NPUModelRunner(GPUModelRunner): attn_state=self.attn_state, max_query_len=max_num_scheduled_tokens, decode_token_per_req=self.decode_token_per_req, - prefill_context_parallel_metadata=long_seq_metadata, - max_seq_len=0) + prefill_context_parallel_metadata=self.long_seq_metadata, + max_seq_len=0, + ) if self.speculative_config and self.pcp_size * self.dcp_size > 1: # For pcp + spec decode, we flatten block_table @@ -1080,8 +1011,10 @@ class NPUModelRunner(GPUModelRunner): # (num_reqs_d + num_reqs_p, max_num_blocks), # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs] - ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs] + ori_query_lens_cpu = self.pcp_manager.query_lens_pcp_full.cpu[: + num_reqs] + ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[: + num_reqs] num_prefill_reqs = (ori_query_lens > self.decode_threshold).sum().item() num_decode_reqs = num_reqs - num_prefill_reqs @@ -1097,13 +1030,17 @@ class NPUModelRunner(GPUModelRunner): ori_query_lens[:num_decode_reqs], dim=0)) common_attn_metadata.block_table_tensor = \ blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] - long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu + assert self.long_seq_metadata is not None + self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu + if 'pad_size' in locals() and pad_size > 0: ori_query_lens_cpu[-pad_size:] = \ torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) - long_seq_metadata.max_query_len_pcp_full = \ + self.long_seq_metadata.max_query_len_pcp_full = \ ori_query_lens_cpu.max().item() + + if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: self.spec_decode_common_attn_metadata = common_attn_metadata @@ -1193,19 +1130,12 @@ class NPUModelRunner(GPUModelRunner): pad_size = get_forward_context().pad_size if pad_size > 0: hidden_states = hidden_states[:-pad_size, :] - - if self.pcp_size > 1: - hidden_states = get_pcp_group().all_gather( - hidden_states[:self.num_actual_tokens_pcp_padded // - self.pcp_size], 0) - hidden_states = torch.index_select( - hidden_states, 0, - self.pcp_allgather_restore_idx[:hidden_states.shape[0]]) - return hidden_states + return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states( + hidden_states) def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): - if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens): + if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0): attn_state = AscendAttentionState.PrefillNoCache # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): @@ -1231,7 +1161,7 @@ class NPUModelRunner(GPUModelRunner): self, num_draft_tokens: np.ndarray, cu_num_scheduled_tokens: np.ndarray, - num_pcp_pads: np.ndarray, + num_pcp_pads: np.ndarray | None, ) -> SpecDecodeMetadata: # Inputs: # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] @@ -1846,7 +1776,9 @@ class NPUModelRunner(GPUModelRunner): self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) - long_seq_metadata = self._generate_pcp_metadata(num_tokens) + long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata( + num_tokens, self.query_lens, self.attn_mask, + self.input_batch) if long_seq_metadata is not None: pcp_world_size = get_pcp_group().world_size dcp_world_size = get_dcp_group().world_size @@ -2890,365 +2822,6 @@ class NPUModelRunner(GPUModelRunner): parent_module_name): super().capture_model() - def _update_tokens_for_pcp(self, tokens): - num_reqs = self.input_batch.num_reqs - tokens = np.array(tokens, dtype=np.int32) - num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum() - num_decode_tokens = sum(tokens[:num_decode_reqs]) - num_padded_scheduled_tokens = np.ceil( - tokens / - (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size) - num_padded_scheduled_tokens[:num_decode_reqs] = ( - tokens[:num_decode_reqs] * self.pcp_size) - self.num_pcp_pads[:num_reqs] = torch.tensor( - num_padded_scheduled_tokens - tokens) - cu_padded_tokens, pcp_padded_arange = \ - self._get_cumsum_and_arange(num_padded_scheduled_tokens) - unpad_mask = torch.from_numpy( - pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens)) - unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size] - unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size]) - unpad_mask_decode[:, 0] = True - unpad_mask_decode[:, 1:] = False - - pcp_tokens = num_padded_scheduled_tokens // self.pcp_size - pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) - pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] - _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) - _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) - pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, - pcp_tokens) - - def get_current_rank_positions(cu_tokens, rank): - positions_start_loc = np.zeros_like(cu_tokens) - positions_start_loc[1:] = cu_tokens[:-1] - positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) - head_start_loc = positions_start_loc + rank * pcp_chunk_sizes - tail_start_loc = positions_start_loc + \ - (2 * self.pcp_size - rank - 1) * pcp_chunk_sizes - positions[pcp_head_chunk_mask] = pcp_chunk_arange + \ - np.repeat(head_start_loc, pcp_chunk_sizes) - # Decode reqs do not have tail chunks. - positions[~pcp_head_chunk_mask] = \ - pcp_chunk_arange[num_decode_tokens:] + \ - np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] - return positions - - positions = get_current_rank_positions( - np.zeros(num_reqs, dtype=np.int32), self.pcp_rank) - # Decode tokens are duplicate and their positions always be 0. - if num_decode_reqs > 0: - positions[:num_decode_tokens] = self._get_cumsum_and_arange( - tokens[:num_decode_reqs])[1] - - all_positions = [ - get_current_rank_positions(cu_padded_tokens, rank_i) - for rank_i in range(self.pcp_size) - ] - all_positions_tensor = torch.from_numpy(np.concatenate(all_positions)) - self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_( - all_positions_tensor.float().argsort().long(), non_blocking=True) - return pcp_tokens, positions, unpad_mask - - def _get_cp_local_seq_lens( - self, - seq_lens: torch.Tensor, - pcp_world_size: int = 1, - dcp_world_size: int = 1, - cp_kv_cache_interleave_size: int = 1, - ) -> torch.Tensor: - """While using pcp or dcp, kv_cache size stored on each rank may be different, - use this function to calculate split decode seq_lens of each (p/d)cp rank. - """ - num_requests = seq_lens.size(0) - total_world_size = pcp_world_size * dcp_world_size - seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) - rank_offsets = (torch.arange(total_world_size, - dtype=torch.int32).unsqueeze(0).repeat( - num_requests, 1)) - base = (seq_lens_tiled // cp_kv_cache_interleave_size // - total_world_size * cp_kv_cache_interleave_size) - remainder = seq_lens_tiled - base * total_world_size - remainder = torch.clip( - remainder - rank_offsets * cp_kv_cache_interleave_size, - 0, - cp_kv_cache_interleave_size, - ) - dcp_local_seq_lens = (base + remainder).reshape( - [-1, pcp_world_size, dcp_world_size]) - return dcp_local_seq_lens - - def _generate_pcp_metadata(self, total_num_scheduled_tokens): - # In dummy run num_reqs == 0, update it from seq_lens - num_reqs = self.input_batch.num_reqs or self.query_lens.size(0) - query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \ - if self.pcp_size > 1 and self.speculative_config else self.query_lens - num_decodes = (query_lens <= self.decode_threshold).sum().item() - num_prefills = num_reqs - num_decodes - num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size - self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded - long_seq_metadata = None - if self.pcp_size * self.dcp_size > 1: - decode_context_lens = self.input_batch.num_tokens[:num_decodes] - prefill_context_lens = self.input_batch.num_computed_tokens_cpu[ - num_decodes:num_reqs] - context_lens = np.concatenate( - [decode_context_lens, prefill_context_lens]) - num_computed_tokens_of_pcp_dcp = torch.zeros( - [ - num_reqs * self.decode_threshold, self.pcp_size, - self.dcp_size - ], - dtype=torch.int32, - ) - # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape. - # Same as block_table, we flatten decode seq_lens to query_lens, - # and keep prefill seq_lens unchanged. - for decode_idx in range(self.decode_threshold): - num_computed_tokens_of_pcp_dcp[ - self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ - self._get_cp_local_seq_lens( - torch.tensor(context_lens) - decode_idx, - self.pcp_size, - self.dcp_size, - self.parallel_config.cp_kv_cache_interleave_size, - ) - if self.decode_threshold > 1: - num_computed_tokens_of_pcp_dcp_list = [] - if num_decodes: - num_decodes_flatten = \ - self.query_lens[:num_decodes].sum().item() - if self.query_lens[:num_decodes].min().item( - ) == self.decode_threshold: - decode_flatten_idx = list(range(num_decodes_flatten)) - else: - decode_flatten_idx = [] - for req_id in range(num_decodes): - offset = (req_id + 1) * self.decode_threshold - decode_flatten_idx += \ - list(range(offset - self.query_lens[req_id], offset)) - num_computed_tokens_of_pcp_dcp_list.append( - num_computed_tokens_of_pcp_dcp[decode_flatten_idx]) - if num_prefills: - num_computed_tokens_of_pcp_dcp_list.append( - num_computed_tokens_of_pcp_dcp[ - (num_decodes + 1) * self.decode_threshold - - 1::self.decode_threshold]) - num_computed_tokens_of_pcp_dcp = torch.cat( - num_computed_tokens_of_pcp_dcp_list, dim=0) - long_seq_metadata = AscendPrefillContextParallelMetadata( - num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, - num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. - numpy()) - if self.pcp_size > 1: - q_head_idx, q_tail_idx = [], [] - kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] - kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - chunk_seqlens = [] - kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] - q_req_offset = 0 - kv_req_offset = 0 - q_head_chunk_id = self.pcp_rank - q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank - for i, seq_len in enumerate(self.query_lens): - if i < num_decodes: - continue - chunk_len = seq_len // 2 - chunk_seqlens.append(chunk_len) - q_head_idx.extend( - list(range(q_req_offset, q_req_offset + chunk_len))) - kv_with_q_head_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_head_chunk_id))) - kv_with_q_head_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_head_chunk_id, - kv_req_offset + chunk_len * - (q_head_chunk_id + 1)))) - kv_with_q_head_nomask_seqlens.append(chunk_len * - q_head_chunk_id) - - q_tail_idx.extend( - list( - range(q_req_offset + chunk_len, - q_req_offset + chunk_len * 2))) - kv_with_q_tail_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_tail_chunk_id))) - kv_with_q_tail_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_tail_chunk_id, - kv_req_offset + chunk_len * - (q_tail_chunk_id + 1)))) - kv_with_q_tail_nomask_seqlens.append(chunk_len * - q_tail_chunk_id) - - q_req_offset += seq_len - kv_req_offset += seq_len * self.pcp_size - - # Convert lists to tensors and move to device - def _list_to_tensor(lst, device, dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), - dtype=dtype, - device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), - non_blocking=True) - return tensor_npu - - q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) - q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) - self.q_head_idx_tensor = q_head_idx_tensor - self.q_tail_idx_tensor = q_tail_idx_tensor - - q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) - q_full_idx = q_full_idx.to(torch.float32).argsort().to( - torch.int32) - self.q_full_idx = q_full_idx - - self.kv_idx_names = { - 'kv_with_q_head_nomask_idx_tensor': - kv_with_q_head_nomask_idx, - 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, - 'kv_with_q_tail_nomask_idx_tensor': - kv_with_q_tail_nomask_idx, - 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx - } - for key, value in self.kv_idx_names.items(): - tensor_npu = _list_to_tensor(value, self.device) - self.kv_idx_names[key] = tensor_npu - - attn_mask_seqlens = torch.tensor( - [chunk_seqlens, chunk_seqlens], dtype=torch.int32) - head_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_head_nomask_seqlens], - dtype=torch.int32) - tail_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_tail_nomask_seqlens], - dtype=torch.int32) - pcp_prefill_mask = self.attn_mask - - self.extra_long_seq_kwargs = { - 'attn_mask_seqlens': attn_mask_seqlens, - 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'pcp_prefill_mask': pcp_prefill_mask - } - long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[: - num_actual_tokens_pcp_padded] - long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk - long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor - long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor - long_seq_metadata.q_full_idx = self.q_full_idx - long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_nomask_idx_tensor'] - long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_mask_idx_tensor'] - long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_nomask_idx_tensor'] - long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_mask_idx_tensor'] - long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ - 'attn_mask_seqlens'] - long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'head_attn_nomask_seqlens'] - long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'tail_attn_nomask_seqlens'] - long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ - 'pcp_prefill_mask'] - self.long_seq_metadata = long_seq_metadata - return long_seq_metadata - - def _generate_pcp_mtp_input( - self, - num_reqs: int, - total_num_scheduled_tokens: int, - num_scheduled_tokens: dict[str, int], - with_prefill: bool = True, - req_indices=None, - positions_np=None, - cu_num_tokens=None, - ): - """ - While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, - but mtp need to shift original input_ids before pcp splitting, - so we record original input_ids here. - """ - total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens - num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) - for i, req_id in enumerate(self.input_batch.req_ids): - num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] - self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy( - num_scheduled_tokens_pcp_full) - req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens_pcp_full) - cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) - self.query_start_loc_pcp_full.np[0] = 0 - self.query_start_loc_pcp_full.np[1:num_reqs + - 1] = cu_num_tokens_pcp_full - self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1) - cumsums_offsets_pcp_full = np.repeat( - cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, - num_scheduled_tokens_pcp_full) - arange_pcp_full = self.arange_np[: - total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full - positions_pcp_full_np = self.positions_pcp_full_np[: - total_num_scheduled_tokens_pcp_full] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], - arange_pcp_full, - out=positions_pcp_full_np) - token_indices_pcp_full = ( - positions_pcp_full_np + - req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices_pcp_full), - out=self.input_ids_pcp_full. - cpu[:total_num_scheduled_tokens_pcp_full]) - self.query_lens_pcp_full.copy_to_gpu() - self.query_start_loc_pcp_full.copy_to_gpu() - self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_( - self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full], - non_blocking=True, - ) - self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full - # For mtpx, pre-allocate mtp slot_mapping here - if self.decode_threshold > 2 and not with_prefill: - num_tokens_ori = sum(list(num_scheduled_tokens.values())) - num_tokens_mtp = \ - num_tokens_ori + num_reqs * (self.decode_threshold - 2) - num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size - req_indices_split = np.array_split(req_indices, - cu_num_tokens)[:num_reqs] - positions_split = np.array_split(positions_np, - cu_num_tokens)[:num_reqs] - for req_idx in range(num_reqs): - ori_req_indice = req_indices_split[req_idx] - ori_position = positions_split[req_idx] - req_indices_split[req_idx] = np.append( - ori_req_indice, - np.repeat(ori_req_indice[-1], self.decode_threshold - 2)) - positions_split[req_idx] = np.append( - ori_position, - np.arange(ori_position[-1] + 1, - ori_position[-1] + self.decode_threshold - 1)) - req_indices_mtp = np.concatenate(req_indices_split) - positions_mtp = np.concatenate(positions_split) - self.input_batch.block_table.compute_slot_mapping( - req_indices_mtp, positions_mtp) - mtp_slot_ori = self.input_batch.block_table.block_tables[ - 0].slot_mapping.cpu[:num_tokens_mtp] - unpad_mask = np.repeat(False, num_tokens_mtp_pad) - unpad_mask[::self.pcp_size] = True - mtp_slot_pad = \ - torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32) - mtp_slot_pad[unpad_mask] = mtp_slot_ori - self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True) - def _prepare_multimodal_fields(self): """ Ensures specific multimodal tensors are on CPU. diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py new file mode 100644 index 00000000..b892f597 --- /dev/null +++ b/vllm_ascend/worker/pcp_utils.py @@ -0,0 +1,686 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/vllm/worker/worker.py +# + +from typing import List + +import numpy as np +import torch +from vllm.config import VllmConfig +from vllm.utils.math_utils import cdiv +from vllm.v1.utils import CpuGpuBuffer + + +class PCPManager: + """ + Manager for Prefill Context Parallelism (PCP) metadata and buffers. + + This manager encapsulates all PCP-related buffers and logic so that the + ModelRunner can access them via `self.pcp_manager`. + """ + + def __init__( + self, + pcp_world_size: int, + pcp_rank: int, + dcp_world_size: int, + dcp_rank: int, + max_buffer_num_tokens: int, + max_num_reqs: int, + device: torch.device, + vllm_config: VllmConfig, + pin_memory: bool = False, + ) -> None: + self.pcp_world_size = pcp_world_size + self.pcp_world_rank = pcp_rank + self.dcp_world_size = dcp_world_size + self.dcp_world_rank = dcp_rank + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + ( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + self.vllm_config = vllm_config + self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs + self.device = device + self.pcp_allgather_restore_idx = CpuGpuBuffer( + max_buffer_num_tokens, + dtype=torch.int64, + device=device, + pin_memory=pin_memory, + ) + self.pcp_padded_slot_mapping = torch.full( + (max_buffer_num_tokens, ), + fill_value=-1, + dtype=torch.int32, + device=device, + ) + self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ), + device="cpu", + dtype=torch.int64) + self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() + self.pcp_unpad_mask_cpu_tensor = torch.zeros( + (max_buffer_num_tokens, ), + device="cpu", + dtype=torch.bool, + ) + self.num_actual_tokens_pcp_padded = 0 + self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() + self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ + [] for _ in range(self.pcp_world_size) + ] + self.full_indices = list( + range(self.max_num_tokens * self.pcp_world_size * + self.dcp_world_size + self.pcp_world_size * + self.dcp_world_size * self.max_num_reqs)) + if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1: + self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens, + dtype=torch.int32, + device=device, + pin_memory=pin_memory) + self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1, + dtype=torch.int32, + device=device, + pin_memory=pin_memory) + self.positions_pcp_full = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=pin_memory) + self.positions_pcp_full_np = self.positions_pcp_full.numpy() + self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs, + dtype=torch.int32, + device=device, + pin_memory=pin_memory) + + def _get_cumsum_and_arange( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + cumsum_dtype: np.dtype | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, + num_scheduled_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def update_tokens_for_pcp( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + num_reqs: int, + reorder_batch_threshold: int | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Update token counts and positions for Prefill Context Parallelism (PCP). + + When using Prefill Context Parallelism, each request's prefill sequence is + split across multiple PCP ranks. The splitting strategy used here is the + "DualChunkSwap" style: each request's (padded) sequence is split into + 2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved + head/tail pattern to balance load. + + This function: + - Computes how many tokens each request should be processed by the current + PCP rank (pcp_tokens). + - Computes the flattened positions of those tokens within the local + padded buffer (pcp_positions). + - Updates runner state arrays used to restore original order and mask out + padded tokens after allgather: + - self.num_pcp_pads_cpu: number of pads added per request + - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the + padded allgather buffer + - self.pcp_allgather_restore_idx: index array used to restore original + ordering after per-rank allgather and interleaving. + + Args: + num_scheduled_tokens: 1D numpy array of length num_reqs containing + the number of new tokens scheduled per request. + arange_np: 1D numpy array of length max_buffer_num_tokens used for + efficient batched arange operations. + num_reqs: Total number of requests in the batch. + reorder_batch_threshold: Threshold for decode vs prefill requests. + + Returns: + Tuple (pcp_tokens, pcp_positions): + - pcp_tokens: number of tokens per request that this PCP rank will + actually process (after splitting / replication). + - pcp_positions: flattened positions for those tokens on this rank, + used to build the positions buffer for the model. + + Example: + >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp. + >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + >>> Meanwhile, the following results are same for each pcp rank + >>> self.num_pcp_pads_cpu + [1, 3, 0] + >>> self.pcp_unpad_mask_cpu + [True, False, True, True, True, True, True, False, False, + False, True, True, True, True, True, True, True, True] + >>> self.pcp_allgather_resotre_idx + [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] + """ + + assert reorder_batch_threshold is not None, ( + "PCP depends on reorder batch to split decode and prefill requests." + ) + num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold) + num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs]) + + # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). + # We first pad each request's token count up to that multiple. + num_padded_scheduled_tokens = np.ceil( + num_scheduled_tokens / (2 * self.pcp_world_size)).astype( + np.int32) * (2 * self.pcp_world_size) + + # PCP does not split decode requests. For decode requests, we instead + # duplicate the scheduled tokens across the pcp_world_size ranks. + num_padded_scheduled_tokens[:num_decode_reqs] = ( + num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size) + + # Record how many pads were added per request (padded - original). + self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens - + num_scheduled_tokens) + + # cu_padded_tokens: cumulative sum of padded token counts, + # pcp_padded_arange: per-request arange flattened for padded tokens. + cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( + num_padded_scheduled_tokens, arange_np) + # Build the mask that marks which positions in the padded allgather buffer + # correspond to real (unpadded) tokens. + self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( + pcp_padded_arange < np.repeat(num_scheduled_tokens, + num_padded_scheduled_tokens)) + unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens * + self.pcp_world_size] + unpad_mask_decode = unpad_mask_decode.reshape( + [-1, self.pcp_world_size]) + unpad_mask_decode[:, 0] = True + unpad_mask_decode[:, 1:] = False + pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + + # Compute per-request "chunk sizes" for the head/tail splitting. + # For prefill requests, we further split the pcp_tokens into two chunks + # (head and tail). For decode requests, the chunk equals pcp_tokens. + pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] + + # Build arange-style helpers for pcp tokens and chunk sizes: + # - pcp_arange gives indices repeated for each token in pcp_tokens + # - pcp_chunk_arange gives indices repeated for each position inside chunks + _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np) + _, pcp_chunk_arange = self._get_cumsum_and_arange( + pcp_chunk_sizes, arange_np) + + # Mask that marks whether a position belongs to the head chunk (True) + # or the tail chunk (False). For decode requests, tail chunk won't exist + # and is handled specially below. + pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, + pcp_tokens) + + def get_current_rank_positions(positions_start_loc: int | np.ndarray, + rank: int): + """ + Compute flattened positions for the given rank with a given start + offset for each request (positions_start_loc). + + - For head chunks: start at positions_start_loc + rank * chunk_size. + - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank - + 1) * chunk_size. + - For decode requests: no tail chunks; their positions are filled from the + contiguous (unpadded) `tokens` arange instead (handled after). + """ + positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) + head_start_loc = positions_start_loc + rank * pcp_chunk_sizes + tail_start_loc = ( + positions_start_loc + + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes) + # Fill head positions using chunk arange offset by head_start_loc. + positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( + head_start_loc, pcp_chunk_sizes) + # Fill tail positions. Note decode requests do not have tail chunks, + # so the tail filling is only for prefill positions. + positions[~pcp_head_chunk_mask] = ( + pcp_chunk_arange[num_decode_tokens:] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]) + return positions + + positions = get_current_rank_positions(0, self.pcp_world_rank) + # Decode tokens are duplicated only after AG. But their positions are + # same without prefill context parallel. + if num_decode_reqs > 0: + positions[:num_decode_tokens] = self._get_cumsum_and_arange( + num_scheduled_tokens[:num_decode_reqs], arange_np)[1] + + # Build the restore index used after allgather. + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) + for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = ( + all_positions.argsort()) + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + + return ( + pcp_tokens[:num_reqs], + positions, + ) + + def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int): + return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size - + self.num_pcp_pads_cpu_tensor[:num_reqs] - 1) + + def get_discard_request_mask( + self, + num_computed_tokens_cpu: np.ndarray, + num_scheduled_tokens: np.ndarray, + num_reqs: int, + num_tokens_np: np.ndarray, + ): + return (num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens * self.pcp_world_size - + self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np + + def get_padded_slot_mapping(self, num_tokens: int, + slot_mapping: torch.Tensor): + # After pcp allgather and restore, there are padded tokens in kv, + # so we need pad slotmapping for alignment. + pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens * + self. + pcp_world_size] + cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens * + self.pcp_world_size] + pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + return pcp_padded_slot_mapping + + def get_restore_hidden_states( + self, + hidden_states: torch.Tensor, + ): + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + from vllm.distributed.parallel_state import get_pcp_group + hidden_states = get_pcp_group().all_gather( + hidden_states[:self.num_actual_tokens_pcp_padded // + self.pcp_world_size], + 0, + ) + restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states. + shape[0]] + return torch.index_select( + hidden_states, + 0, + restore_idx, + ) + + def generate_pcp_mtp_input( + self, + num_reqs: int, + total_num_scheduled_tokens: int, + num_scheduled_tokens: dict[str, int], + with_prefill: bool = True, + input_batch=None, + arange_np=None, + req_indices=None, + positions_np=None, + cu_num_tokens=None, + ): + """ + While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, + but mtp need to shift original input_ids before pcp splitting, + so we record original input_ids here. + """ + total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens + num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) + for i, req_id in enumerate(input_batch.req_ids): + num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] + self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy( + num_scheduled_tokens_pcp_full) + req_indices_pcp_full = np.repeat(arange_np[:num_reqs], + num_scheduled_tokens_pcp_full) + cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) + self.query_start_loc_pcp_full.np[0] = 0 + self.query_start_loc_pcp_full.np[1:num_reqs + + 1] = cu_num_tokens_pcp_full + self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1) + cumsums_offsets_pcp_full = np.repeat( + cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, + num_scheduled_tokens_pcp_full) + arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full + positions_pcp_full_np = self.positions_pcp_full_np[: + total_num_scheduled_tokens_pcp_full] + np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], + arange_pcp_full, + out=positions_pcp_full_np) + token_indices_pcp_full = ( + positions_pcp_full_np + + req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]) + torch.index_select(input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices_pcp_full), + out=self.input_ids_pcp_full. + cpu[:total_num_scheduled_tokens_pcp_full]) + self.query_lens_pcp_full.copy_to_gpu() + self.query_start_loc_pcp_full.copy_to_gpu() + self.input_ids_pcp_full.copy_to_gpu( + total_num_scheduled_tokens_pcp_full) + self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full + # For mtpx, pre-allocate mtp slot_mapping here + if self.decode_threshold > 2 and not with_prefill: + num_tokens_ori = sum(list(num_scheduled_tokens.values())) + num_tokens_mtp = \ + num_tokens_ori + num_reqs * (self.decode_threshold - 2) + num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size + req_indices_split = np.array_split(req_indices, + cu_num_tokens)[:num_reqs] + positions_split = np.array_split(positions_np, + cu_num_tokens)[:num_reqs] + for req_idx in range(num_reqs): + ori_req_indice = req_indices_split[req_idx] + ori_position = positions_split[req_idx] + req_indices_split[req_idx] = np.append( + ori_req_indice, + np.repeat(ori_req_indice[-1], self.decode_threshold - 2)) + positions_split[req_idx] = np.append( + ori_position, + np.arange(ori_position[-1] + 1, + ori_position[-1] + self.decode_threshold - 1)) + req_indices_mtp = np.concatenate(req_indices_split) + positions_mtp = np.concatenate(positions_split) + input_batch.block_table.compute_slot_mapping( + req_indices_mtp, positions_mtp) + mtp_slot_ori = input_batch.block_table.block_tables[ + 0].slot_mapping.cpu[:num_tokens_mtp] + unpad_mask = np.repeat(False, num_tokens_mtp_pad) + unpad_mask[::self.pcp_world_size] = True + mtp_slot_pad = \ + torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32) + mtp_slot_pad[unpad_mask] = mtp_slot_ori + self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True) + + def _get_cp_local_seq_lens( + self, + seq_lens: torch.Tensor, + pcp_world_size: int = 1, + dcp_world_size: int = 1, + cp_kv_cache_interleave_size: int = 1, + ) -> torch.Tensor: + """While using pcp or dcp, kv_cache size stored on each rank may be different, + use this function to calculate split decode seq_lens of each (p/d)cp rank. + """ + num_requests = seq_lens.size(0) + total_world_size = pcp_world_size * dcp_world_size + seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) + rank_offsets = (torch.arange(total_world_size, + dtype=torch.int32).unsqueeze(0).repeat( + num_requests, 1)) + base = (seq_lens_tiled // cp_kv_cache_interleave_size // + total_world_size * cp_kv_cache_interleave_size) + remainder = seq_lens_tiled - base * total_world_size + remainder = torch.clip( + remainder - rank_offsets * cp_kv_cache_interleave_size, + 0, + cp_kv_cache_interleave_size, + ) + dcp_local_seq_lens = (base + remainder).reshape( + [-1, pcp_world_size, dcp_world_size]) + return dcp_local_seq_lens + + def generate_kv_idx(self, scheduler_output, input_batch): + if not self.pcp_world_size > 1: + return + self.cp_kv_recover_idx_for_chunk = [[] + for _ in range(self.pcp_world_size) + ] + + for i, req_id in enumerate(input_batch.req_ids): + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + is_prefill = input_batch.num_computed_tokens_cpu[ + i] < input_batch.num_prompt_tokens[i] + if is_prefill: + num_cp_padded_scheduled_tokens = cdiv( + num_scheduled_tokens, + 2 * self.pcp_world_size) * (2 * self.pcp_world_size) + chunk_size = num_cp_padded_scheduled_tokens // ( + 2 * self.pcp_world_size) + num_added_recover_tokens = len( + self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size + for rank in range(self.pcp_world_size): + self.cp_kv_recover_idx_for_chunk[rank].extend( + self.full_indices[rank * chunk_size + + num_added_recover_tokens:(rank + 1) * + chunk_size + + num_added_recover_tokens]) + self.cp_kv_recover_idx_for_chunk[rank].extend( + self.full_indices[num_cp_padded_scheduled_tokens - + (rank + 1) * chunk_size + + num_added_recover_tokens: + num_cp_padded_scheduled_tokens - + rank * chunk_size + + num_added_recover_tokens]) + + cp_kv_recover_idx_for_chunk = torch.from_numpy( + np.concatenate( + self.cp_kv_recover_idx_for_chunk)).to(device=self.device) + cp_kv_recover_idx_for_chunk.copy_(torch.tensor( + np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( + torch.float32).argsort().to(torch.int32) + + def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, + attn_mask, input_batch): + from vllm_ascend.attention.utils import \ + AscendPrefillContextParallelMetadata + num_reqs = input_batch.num_reqs or query_lens.size(0) + query_lens_new = self.query_lens_pcp_full.cpu[:num_reqs] \ + if self.pcp_world_size > 1 and self.speculative_config else query_lens + num_decodes = (query_lens_new <= self.decode_threshold).sum().item() + num_prefills = num_reqs - num_decodes + num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size + self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded + long_seq_metadata = None + if self.pcp_world_size * self.dcp_world_size > 1: + decode_context_lens = input_batch.num_tokens[:num_decodes] + prefill_context_lens = input_batch.num_computed_tokens_cpu[ + num_decodes:num_reqs] + context_lens = np.concatenate( + [decode_context_lens, prefill_context_lens]) + num_computed_tokens_of_pcp_dcp = torch.zeros( + [ + num_reqs * self.decode_threshold, self.pcp_world_size, + self.dcp_world_size + ], + dtype=torch.int32, + ) + # For pcp + spec decode, we flatten seq_lens + # to avoid irregular spec_attn_mask shape. + # Same as block_table, we flatten decode seq_lens to query_lens, + # and keep prefill seq_lens unchanged. + for decode_idx in range(self.decode_threshold): + num_computed_tokens_of_pcp_dcp[ + self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ + self._get_cp_local_seq_lens( + torch.tensor(context_lens) - decode_idx, + self.pcp_world_size, + self.dcp_world_size, + self.vllm_config.parallel_config.cp_kv_cache_interleave_size, + ) + if self.decode_threshold > 1: + num_computed_tokens_of_pcp_dcp_list = [] + if num_decodes: + num_decodes_flatten = \ + query_lens[:num_decodes].sum().item() + if query_lens[:num_decodes].min().item( + ) == self.decode_threshold: + decode_flatten_idx = list(range(num_decodes_flatten)) + else: + decode_flatten_idx = [] + for req_id in range(num_decodes): + offset = (req_id + 1) * self.decode_threshold + decode_flatten_idx += \ + list(range(offset - query_lens[req_id], offset)) + num_computed_tokens_of_pcp_dcp_list.append( + num_computed_tokens_of_pcp_dcp[decode_flatten_idx]) + if num_prefills: + num_computed_tokens_of_pcp_dcp_list.append( + num_computed_tokens_of_pcp_dcp[ + (num_decodes + 1) * self.decode_threshold - + 1::self.decode_threshold]) + num_computed_tokens_of_pcp_dcp = torch.cat( + num_computed_tokens_of_pcp_dcp_list, dim=0) + long_seq_metadata = AscendPrefillContextParallelMetadata( + num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, + num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. + numpy()) + if self.pcp_world_size > 1: + q_head_idx, q_tail_idx = [], [] + kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] + kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] + chunk_seqlens = [] + kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] + q_req_offset = 0 + kv_req_offset = 0 + q_head_chunk_id = self.pcp_world_rank + q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank + for i, seq_len in enumerate(query_lens): + if i < num_decodes: + continue + chunk_len = seq_len // 2 + chunk_seqlens.append(chunk_len) + q_head_idx.extend( + list(range(q_req_offset, q_req_offset + chunk_len))) + kv_with_q_head_nomask_idx.extend( + list( + range(kv_req_offset, kv_req_offset + + chunk_len * q_head_chunk_id))) + kv_with_q_head_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_head_chunk_id, + kv_req_offset + chunk_len * + (q_head_chunk_id + 1)))) + kv_with_q_head_nomask_seqlens.append(chunk_len * + q_head_chunk_id) + + q_tail_idx.extend( + list( + range(q_req_offset + chunk_len, + q_req_offset + chunk_len * 2))) + kv_with_q_tail_nomask_idx.extend( + list( + range(kv_req_offset, kv_req_offset + + chunk_len * q_tail_chunk_id))) + kv_with_q_tail_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_tail_chunk_id, + kv_req_offset + chunk_len * + (q_tail_chunk_id + 1)))) + kv_with_q_tail_nomask_seqlens.append(chunk_len * + q_tail_chunk_id) + + q_req_offset += seq_len + kv_req_offset += seq_len * self.pcp_world_size + + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), + dtype=dtype, + device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) + self.q_head_idx_tensor = q_head_idx_tensor + self.q_tail_idx_tensor = q_tail_idx_tensor + + q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) + q_full_idx = q_full_idx.to(torch.float32).argsort().to( + torch.int32) + self.q_full_idx = q_full_idx + + self.kv_idx_names = { + 'kv_with_q_head_nomask_idx_tensor': + kv_with_q_head_nomask_idx, + 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, + 'kv_with_q_tail_nomask_idx_tensor': + kv_with_q_tail_nomask_idx, + 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx + } + for key, value in self.kv_idx_names.items(): + tensor_npu = _list_to_tensor(value, self.device) + self.kv_idx_names[key] = tensor_npu + + attn_mask_seqlens = torch.tensor( + [chunk_seqlens, chunk_seqlens], dtype=torch.int32) + head_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_head_nomask_seqlens], + dtype=torch.int32) + tail_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_tail_nomask_seqlens], + dtype=torch.int32) + pcp_prefill_mask = attn_mask + + self.extra_long_seq_kwargs = { + 'attn_mask_seqlens': attn_mask_seqlens, + 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, + 'pcp_prefill_mask': pcp_prefill_mask + } + long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[: + num_actual_tokens_pcp_padded] + long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk + long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor + long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor + long_seq_metadata.q_full_idx = self.q_full_idx + long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_head_nomask_idx_tensor'] + long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_head_mask_idx_tensor'] + long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_tail_nomask_idx_tensor'] + long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_tail_mask_idx_tensor'] + long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ + 'attn_mask_seqlens'] + long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ + 'head_attn_nomask_seqlens'] + long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ + 'tail_attn_nomask_seqlens'] + long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ + 'pcp_prefill_mask'] + self.long_seq_metadata = long_seq_metadata + return long_seq_metadata